[Mlir-commits] [mlir] 295a6ed - [TOSA] Fold consecutive concats on same axis
Dominik Montada
llvmlistbot at llvm.org
Wed May 24 03:02:38 PDT 2023
Author: Dominik Montada
Date: 2023-05-24T10:00:56Z
New Revision: 295a6ed5d54aca2d923b76b6388f9732eb37f548
URL: https://github.com/llvm/llvm-project/commit/295a6ed5d54aca2d923b76b6388f9732eb37f548
DIFF: https://github.com/llvm/llvm-project/commit/295a6ed5d54aca2d923b76b6388f9732eb37f548.diff
LOG: [TOSA] Fold consecutive concats on same axis
Consecutive concats that happen on the same axis can be folded into a
single, bigger concat. This patch implements this folding in the
tosa::ConcatOp::fold method.
Differential Revision: https://reviews.llvm.org/D151210
Added:
mlir/test/Dialect/Tosa/fold_concats.mlir
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index f064e7a180441..359690db1eb7b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1500,6 +1500,7 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [
);
let hasCanonicalizer = 1;
+ let hasFolder = 1;
let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 289f19559ad64..a70749e89a887 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1037,3 +1037,37 @@ OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
return {};
}
+
+/// Folds consecutive tosa.concat ops on the same axis into this op.
+///
+/// Every operand produced by another ConcatOp with the same axis is replaced
+/// in place by that producer's operands, keeping the original operand order.
+/// Returns this op's own result when operands were rewritten (the MLIR
+/// convention for an in-place fold), or a null OpFoldResult when no operand
+/// was foldable. `adaptor` is unused: the fold inspects producer ops, not
+/// constant operand attributes.
+///
+/// Only immediate producers are inspected; longer chains are presumably
+/// flattened by the driver re-folding the op. The producer concats are not
+/// erased here. NOTE(review): confirm dead producers are cleaned up by the
+/// caller (e.g. the canonicalizer's dead-code elimination).
+OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
+ // Fold consecutive concats on the same axis into a single op.
+ // Keep track of the operands so we are able to construct a new concat
+ // later. Conservatively assume that we double the number of operands when
+ // folding. This is only a capacity hint; SmallVector grows if exceeded.
+ SmallVector<Value, 8> concatOperands;
+ concatOperands.reserve(2 * getNumOperands());
+
+ // Find all operands that are foldable concats.
+ bool foundFoldableConcat = false;
+ for (Value operand : getOperands()) {
+ // Optimistically keep the operand; it is popped again below if it turns
+ // out to be a foldable concat.
+ concatOperands.emplace_back(operand);
+
+ // Block arguments have no defining op, hence dyn_cast_or_null.
+ auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
+ if (!producer)
+ continue;
+
+ // Not foldable if axes are not the same: concats along different axes
+ // cannot be expressed as a single concat.
+ if (getAxis() != producer.getAxis())
+ continue;
+
+ // Replace the original operand with all incoming operands, spliced at
+ // the same position so the concatenation order is unchanged.
+ foundFoldableConcat = true;
+ concatOperands.pop_back();
+ llvm::append_range(concatOperands, producer->getOperands());
+ }
+
+ if (!foundFoldableConcat)
+ return {};
+
+ // Rewrite this op in place. The result type is left as is: the spliced
+ // producer inputs span the same extent along the axis as the producer's
+ // result did.
+ getOperation()->setOperands(concatOperands);
+ return getResult();
+}
diff --git a/mlir/test/Dialect/Tosa/fold_concats.mlir b/mlir/test/Dialect/Tosa/fold_concats.mlir
new file mode 100644
index 0000000000000..781d527cc32f2
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/fold_concats.mlir
@@ -0,0 +1,93 @@
+// RUN: mlir-opt --split-input-file --canonicalize %s | FileCheck %s
+
+// A concat whose operands are function arguments (not concats) has nothing
+// to fold and must be left unchanged.
+func.func @single_concat(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> {
+ %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+ return %0 : tensor<1x2x7x7xf32>
+}
+
+// CHECK-LABEL: func.func @single_concat(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> {
+// CHECK: %[[VAL_1:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_0]]) <{axis = 1 : i64}> : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+// CHECK: return %[[VAL_1]] : tensor<1x2x7x7xf32>
+// CHECK: }
+
+// -----
+
+// Concats on different axes (1 then 0) cannot be merged, so both ops must
+// survive canonicalization unchanged.
+func.func @concat_different_axis(%arg0: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> {
+ %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+ %1 = "tosa.concat"(%0, %0) {axis = 0} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32>
+ return %1 : tensor<2x2x7x7xf32>
+}
+
+// CHECK-LABEL: func.func @concat_different_axis(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> {
+// CHECK: %[[VAL_1:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_0]]) <{axis = 1 : i64}> : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_1]]) <{axis = 0 : i64}> : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32>
+// CHECK: return %[[VAL_2]] : tensor<2x2x7x7xf32>
+// CHECK: }
+
+// -----
+
+// A same-axis producer concat in the middle of the operand list is spliced
+// in at its position, preserving operand order (tmp, arg0, arg0, tmp).
+func.func @fold_concats(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
+ %tmp = tensor.empty() : tensor<1x1x7x7xf32>
+ %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+ %1 = "tosa.concat"(%tmp, %0, %tmp) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
+ return %1 : tensor<1x4x7x7xf32>
+}
+
+// CHECK-LABEL: func.func @fold_concats(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
+// CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<1x1x7x7xf32>
+// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]]) <{axis = 1 : i64}> : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
+// CHECK: return %[[VAL_2]] : tensor<1x4x7x7xf32>
+// CHECK: }
+
+// -----
+
+// A chain of three same-axis concats flattens into a single concat over all
+// eight leaf operands.
+func.func @nested_fold(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32> {
+ %tmp = tensor.empty() : tensor<1x1x7x7xf32>
+ %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+ %1 = "tosa.concat"(%tmp, %0, %tmp) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
+ %2 = "tosa.concat"(%1, %1) {axis = 1} : (tensor<1x4x7x7xf32>, tensor<1x4x7x7xf32>) -> tensor<1x8x7x7xf32>
+ return %2 : tensor<1x8x7x7xf32>
+}
+
+// CHECK-LABEL: func.func @nested_fold(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32> {
+// CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<1x1x7x7xf32>
+// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]], %[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]]) <{axis = 1 : i64}> : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32>
+// CHECK: return %[[VAL_2]] : tensor<1x8x7x7xf32>
+// CHECK: }
+
+// -----
+
+// When several operands are same-axis producer concats, all of them are
+// folded in a single concat.
+func.func @wide_fold(%arg0: tensor<1x1x7x7xf32>, %arg1: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
+ %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+ %1 = "tosa.concat"(%arg1, %arg1) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+ %2 = "tosa.concat"(%0, %1) {axis = 1} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32>
+ return %2 : tensor<1x4x7x7xf32>
+}
+
+// CHECK-LABEL: func.func @wide_fold(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
+// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_0]], %[[VAL_1]], %[[VAL_1]]) <{axis = 1 : i64}> : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
+// CHECK: return %[[VAL_2]] : tensor<1x4x7x7xf32>
+// CHECK: }
+
+// -----
+
+// Only the producer concat on the same axis (axis 1) is folded; the axis-2
+// producer stays a separate op and remains an operand of the merged concat.
+func.func @partially_foldable(%arg0: tensor<1x1x8x8xf32>, %arg1: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> {
+ %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>) -> tensor<1x2x8x8xf32>
+ %1 = "tosa.concat"(%arg1, %arg1) {axis = 2} : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32>
+ %2 = "tosa.concat"(%0, %1) {axis = 1} : (tensor<1x2x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
+ return %2 : tensor<1x4x8x8xf32>
+}
+
+// CHECK-LABEL: func.func @partially_foldable(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x8x8xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> {
+// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_1]]) <{axis = 2 : i64}> : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32>
+// CHECK: %[[VAL_3:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_0]], %[[VAL_2]]) <{axis = 1 : i64}> : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
+// CHECK: return %[[VAL_3]] : tensor<1x4x8x8xf32>
+// CHECK: }
More information about the Mlir-commits mailing list