[Mlir-commits] [mlir] 295a6ed - [TOSA] Fold consecutive concats on same axis

Dominik Montada llvmlistbot at llvm.org
Wed May 24 03:02:38 PDT 2023


Author: Dominik Montada
Date: 2023-05-24T10:00:56Z
New Revision: 295a6ed5d54aca2d923b76b6388f9732eb37f548

URL: https://github.com/llvm/llvm-project/commit/295a6ed5d54aca2d923b76b6388f9732eb37f548
DIFF: https://github.com/llvm/llvm-project/commit/295a6ed5d54aca2d923b76b6388f9732eb37f548.diff

LOG: [TOSA] Fold consecutive concats on same axis

Consecutive concats that happen on the same axis can be folded into a
single, bigger concat. This patch implements this folding by
implementing the tosa::ConcatOp::fold method.

Differential Revision: https://reviews.llvm.org/D151210

Added: 
    mlir/test/Dialect/Tosa/fold_concats.mlir

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index f064e7a180441..359690db1eb7b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1500,6 +1500,7 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [
   );
 
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 289f19559ad64..a70749e89a887 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1037,3 +1037,37 @@ OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
 
   return {};
 }
+
+OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
+  // Fold consecutive concats on the same axis into a single op by rewriting
+  // this op's operand list in place. The flattened operands are gathered
+  // first; reserving twice the current operand count is only a conservative
+  // sizing guess for the case where producers get inlined.
+  SmallVector<Value, 8> concatOperands;
+  concatOperands.reserve(2 * getNumOperands());
+
+  // Find all operands that are produced by a concat on the same axis.
+  bool foundFoldableConcat = false;
+  for (Value operand : getOperands()) {
+    concatOperands.emplace_back(operand);
+
+    auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
+    if (!producer)
+      continue;
+
+    // Not foldable if the producer concatenates along a different axis.
+    if (getAxis() != producer.getAxis())
+      continue;
+
+    // Replace the operand appended above with the producer's own operands.
+    foundFoldableConcat = true;
+    concatOperands.pop_back();
+    llvm::append_range(concatOperands, producer->getOperands());
+  }
+
+  if (!foundFoldableConcat)
+    return {};
+
+  getOperation()->setOperands(concatOperands);
+  return getResult();
+}

diff --git a/mlir/test/Dialect/Tosa/fold_concats.mlir b/mlir/test/Dialect/Tosa/fold_concats.mlir
new file mode 100644
index 0000000000000..781d527cc32f2
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/fold_concats.mlir
@@ -0,0 +1,93 @@
+// RUN: mlir-opt --split-input-file --canonicalize %s | FileCheck %s
+
+func.func @single_concat(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> {
+  %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  return %0 : tensor<1x2x7x7xf32>
+}
+
+// CHECK-LABEL:   func.func @single_concat(
+// CHECK-SAME:                             %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> {
+// CHECK:           %[[VAL_1:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_0]]) <{axis = 1 : i64}> : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+// CHECK:           return %[[VAL_1]] : tensor<1x2x7x7xf32>
+// CHECK:         }
+
+// -----
+
+func.func @concat_different_axis(%arg0: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> {
+  %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  %1 = "tosa.concat"(%0, %0) {axis = 0} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32>
+  return %1 : tensor<2x2x7x7xf32>
+}
+
+// CHECK-LABEL:   func.func @concat_different_axis(
+// CHECK-SAME:                                     %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> {
+// CHECK:           %[[VAL_1:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_0]]) <{axis = 1 : i64}> : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+// CHECK:           %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_1]]) <{axis = 0 : i64}> : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32>
+// CHECK:           return %[[VAL_2]] : tensor<2x2x7x7xf32>
+// CHECK:         }
+
+// -----
+
+func.func @fold_concats(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
+  %tmp = tensor.empty() : tensor<1x1x7x7xf32>
+  %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  %1 = "tosa.concat"(%tmp, %0, %tmp) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
+  return %1 : tensor<1x4x7x7xf32>
+}
+
+// CHECK-LABEL:   func.func @fold_concats(
+// CHECK-SAME:                            %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
+// CHECK:           %[[VAL_1:.*]] = tensor.empty() : tensor<1x1x7x7xf32>
+// CHECK:           %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]]) <{axis = 1 : i64}> : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
+// CHECK:           return %[[VAL_2]] : tensor<1x4x7x7xf32>
+// CHECK:         }
+
+// -----
+
+func.func @nested_fold(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32> {
+  %tmp = tensor.empty() : tensor<1x1x7x7xf32>
+  %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  %1 = "tosa.concat"(%tmp, %0, %tmp) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
+  %2 = "tosa.concat"(%1, %1) {axis = 1} : (tensor<1x4x7x7xf32>, tensor<1x4x7x7xf32>) -> tensor<1x8x7x7xf32>
+  return %2 : tensor<1x8x7x7xf32>
+}
+
+// CHECK-LABEL:   func.func @nested_fold(
+// CHECK-SAME:                           %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32> {
+// CHECK:           %[[VAL_1:.*]] = tensor.empty() : tensor<1x1x7x7xf32>
+// CHECK:           %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]], %[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]]) <{axis = 1 : i64}> : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32>
+// CHECK:           return %[[VAL_2]] : tensor<1x8x7x7xf32>
+// CHECK:         }
+
+// -----
+
+func.func @wide_fold(%arg0: tensor<1x1x7x7xf32>, %arg1: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
+  %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  %1 = "tosa.concat"(%arg1, %arg1) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  %2 = "tosa.concat"(%0, %1) {axis = 1} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32>
+  return %2 : tensor<1x4x7x7xf32>
+}
+
+// CHECK-LABEL:   func.func @wide_fold(
+// CHECK-SAME:                         %[[VAL_0:.*]]: tensor<1x1x7x7xf32>,
+// CHECK-SAME:                         %[[VAL_1:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
+// CHECK:           %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_0]], %[[VAL_1]], %[[VAL_1]]) <{axis = 1 : i64}> : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
+// CHECK:           return %[[VAL_2]] : tensor<1x4x7x7xf32>
+// CHECK:         }
+
+// -----
+
+func.func @partially_foldable(%arg0: tensor<1x1x8x8xf32>, %arg1: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> {
+  %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>) -> tensor<1x2x8x8xf32>
+  %1 = "tosa.concat"(%arg1, %arg1) {axis = 2} : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32>
+  %2 = "tosa.concat"(%0, %1) {axis = 1} : (tensor<1x2x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
+  return %2 : tensor<1x4x8x8xf32>
+}
+
+// CHECK-LABEL:   func.func @partially_foldable(
+// CHECK-SAME:                                  %[[VAL_0:.*]]: tensor<1x1x8x8xf32>,
+// CHECK-SAME:                                  %[[VAL_1:.*]]: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> {
+// CHECK:           %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_1]]) <{axis = 2 : i64}> : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32>
+// CHECK:           %[[VAL_3:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_0]], %[[VAL_2]]) <{axis = 1 : i64}> : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
+// CHECK:           return %[[VAL_3]] : tensor<1x4x8x8xf32>
+// CHECK:         }


        


More information about the Mlir-commits mailing list