[Mlir-commits] [mlir] [mlir][tosa] Add folding for TOSA ArgMax operator (PR #88871)
Dmitrii Agibov
llvmlistbot at llvm.org
Wed Apr 17 02:55:17 PDT 2024
https://github.com/d-agbv updated https://github.com/llvm/llvm-project/pull/88871
From 358038968748e96c35ea437c969bee744fd7eb66 Mon Sep 17 00:00:00 2001
From: Dmitrii Agibov <dmitrii.agibov at arm.com>
Date: Tue, 16 Apr 2024 11:44:25 +0100
Subject: [PATCH 1/3] [mlir][tosa] Add folding for TOSA ArgMax operator
The TOSA ArgMax operator can be folded into a constant tensor
filled with zeros when the size of the dimension along the
selected axis equals one and both the input and output types
are statically shaped ranked tensors.
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 1 +
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 14 ++++++++++
mlir/test/Dialect/Tosa/constant-op-fold.mlir | 26 ++++++++++++++-----
3 files changed, 35 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 306e4a43952088..dde17e2dc8924d 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -49,6 +49,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
Tosa_Tensor: $output
);
+ let hasFolder = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index d23c9fe824c94a..53ae7211f987e2 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -507,6 +507,20 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
resultTy);
}
+OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
+ auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
+ auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
+ if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
+ !outputTy.hasStaticShape())
+ return {};
+
+ if (inputTy.getDimSize(getAxis()) == 1) {
+ return DenseElementsAttr::get(outputTy, 0);
+ }
+
+ return {};
+}
+
OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index de752f31fcbaa1..1513d4f772330d 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -3,6 +3,20 @@
// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold="aggressive-reduce-constant=true" %s | FileCheck %s --check-prefix=AGGRESIVE
+// CHECK-LABEL: @argmax_fold_dim_1
+func.func @argmax_fold_dim_1(%arg0: tensor<2x1x3xf32>) -> tensor<2x3xi32> {
+ // CHECK: "tosa.const"() <{value = dense<0> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
+ %0 = tosa.argmax %arg0 {axis = 1 : i32}: (tensor<2x1x3xf32>) -> tensor<2x3xi32>
+ return %0 : tensor<2x3xi32>
+}
+
+// CHECK-LABEL: @argmax_dynamic_shape_no_fold_dim_1
+func.func @argmax_dynamic_shape_no_fold_dim_1(%arg0: tensor<?x1x3xf32>) -> tensor<?x3xi32> {
+ // CHECK: tosa.argmax
+ %0 = tosa.argmax %arg0 {axis = 1 : i32}: (tensor<?x1x3xf32>) -> tensor<?x3xi32>
+ return %0 : tensor<?x3xi32>
+}
+
// CHECK-LABEL: @transpose_fold
func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
// CHECK: return %arg0
@@ -1100,9 +1114,9 @@ func.func @reduce_sum_constant_aggressive() -> tensor<2x3xi32> {
// AGGRESIVE-DAG: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<2> : tensor<1x2x3xi32>}> : () -> tensor<1x2x3xi32>
// AGGRESIVE-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x2x3xi32>}> : () -> tensor<2x2x3xi32>
// AGGRESIVE-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<2> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
- // AGGRESIVE: %[[VAL_3:.*]] = tosa.argmax %[[VAL_0]] {axis = 0 : i32} : (tensor<1x2x3xi32>) -> tensor<2x3xi32>
+ // AGGRESIVE: %[[VAL_3:.*]] = tosa.argmax %[[VAL_0]] {axis = 1 : i32} : (tensor<1x2x3xi32>) -> tensor<1x3xi32>
// AGGRESIVE: %[[VAL_4:.*]] = tosa.argmax %[[VAL_1]] {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<2x3xi32>
- // AGGRESIVE: %[[VAL_5:.*]] = tosa.add %[[VAL_3]], %[[VAL_2]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+ // AGGRESIVE: %[[VAL_5:.*]] = tosa.add %[[VAL_3]], %[[VAL_2]] : (tensor<1x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
// AGGRESIVE: %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_4]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
// AGGRESIVE: return %[[VAL_6]] : tensor<2x3xi32>
@@ -1110,18 +1124,18 @@ func.func @reduce_sum_constant_aggressive() -> tensor<2x3xi32> {
// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x2x3xi32>}> : () -> tensor<2x2x3xi32>
// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<2> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
// CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<1x2x3xi32>
- // CHECK: %[[VAL_3:.*]] = tosa.argmax %[[VAL_2]] {axis = 0 : i32} : (tensor<1x2x3xi32>) -> tensor<2x3xi32>
+ // CHECK: %[[VAL_3:.*]] = tosa.argmax %[[VAL_2]] {axis = 1 : i32} : (tensor<1x2x3xi32>) -> tensor<1x3xi32>
// CHECK: %[[VAL_4:.*]] = tosa.argmax %[[VAL_0]] {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<2x3xi32>
- // CHECK: %[[VAL_5:.*]] = tosa.add %[[VAL_3]], %[[VAL_1]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+ // CHECK: %[[VAL_5:.*]] = tosa.add %[[VAL_3]], %[[VAL_1]] : (tensor<1x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK: %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_4]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK: return %[[VAL_6]] : tensor<2x3xi32>
%const0 = "tosa.const"() {value = dense<1> : tensor<2x2x3xi32>} : () -> tensor<2x2x3xi32>
%const1 = "tosa.const"() {value = dense<2> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
%reduce0 = tosa.reduce_sum %const0 {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<1x2x3xi32>
- %argmax0 = tosa.argmax %reduce0 {axis = 0 : i32} : (tensor<1x2x3xi32>) -> tensor<2x3xi32>
+ %argmax0 = tosa.argmax %reduce0 {axis = 1 : i32} : (tensor<1x2x3xi32>) -> tensor<1x3xi32>
%argmax1 = tosa.argmax %const0 {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<2x3xi32>
- %res0 = tosa.add %argmax0, %const1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+ %res0 = tosa.add %argmax0, %const1 : (tensor<1x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%res1 = tosa.add %res0, %argmax1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
return %res1 : tensor<2x3xi32>
}
From a55e1fda3c392708c96e241c712e1ab8430e5b19 Mon Sep 17 00:00:00 2001
From: Dmitrii Agibov <dmitrii.agibov at arm.com>
Date: Wed, 17 Apr 2024 10:11:23 +0100
Subject: [PATCH 2/3] [mlir][tosa] Add folding for TOSA ArgMax operator
Update after review:
- Make code formatting consistent
- Rename LIT test
---
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 3 +--
mlir/test/Dialect/Tosa/constant-op-fold.mlir | 4 ++--
2 files changed, 3 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 53ae7211f987e2..c8bf4c526b239f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -514,9 +514,8 @@ OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
!outputTy.hasStaticShape())
return {};
- if (inputTy.getDimSize(getAxis()) == 1) {
+ if (inputTy.getDimSize(getAxis()) == 1)
return DenseElementsAttr::get(outputTy, 0);
- }
return {};
}
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 1513d4f772330d..35ba00adfe6462 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -3,8 +3,8 @@
// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold="aggressive-reduce-constant=true" %s | FileCheck %s --check-prefix=AGGRESIVE
-// CHECK-LABEL: @argmax_fold_dim_1
-func.func @argmax_fold_dim_1(%arg0: tensor<2x1x3xf32>) -> tensor<2x3xi32> {
+// CHECK-LABEL: @armax_fold_dim_size_1
+func.func @armax_fold_dim_size_1(%arg0: tensor<2x1x3xf32>) -> tensor<2x3xi32> {
// CHECK: "tosa.const"() <{value = dense<0> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
%0 = tosa.argmax %arg0 {axis = 1 : i32}: (tensor<2x1x3xf32>) -> tensor<2x3xi32>
return %0 : tensor<2x3xi32>
From 4451365eba4c5968c264cebeaa7e45e6662ccf17 Mon Sep 17 00:00:00 2001
From: Dmitrii Agibov <dmitrii.agibov at arm.com>
Date: Wed, 17 Apr 2024 10:53:19 +0100
Subject: [PATCH 3/3] [mlir][tosa] Add folding for TOSA ArgMax operator
Update after review:
- Rename LIT test
---
mlir/test/Dialect/Tosa/constant-op-fold.mlir | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 35ba00adfe6462..c9c60a94bf9ed0 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -10,8 +10,8 @@ func.func @armax_fold_dim_size_1(%arg0: tensor<2x1x3xf32>) -> tensor<2x3xi32> {
return %0 : tensor<2x3xi32>
}
-// CHECK-LABEL: @argmax_dynamic_shape_no_fold_dim_1
-func.func @argmax_dynamic_shape_no_fold_dim_1(%arg0: tensor<?x1x3xf32>) -> tensor<?x3xi32> {
+// CHECK-LABEL: @argmax_dynamic_shape_no_fold_dim_size_1
+func.func @argmax_dynamic_shape_no_fold_dim_size_1(%arg0: tensor<?x1x3xf32>) -> tensor<?x3xi32> {
// CHECK: tosa.argmax
%0 = tosa.argmax %arg0 {axis = 1 : i32}: (tensor<?x1x3xf32>) -> tensor<?x3xi32>
return %0 : tensor<?x3xi32>
More information about the Mlir-commits
mailing list