[Mlir-commits] [mlir] [mlir][tosa] Limit consecutive concat rewrite to MAX_TENSOR_LIST_SIZE (PR #199051)

Luke Hutton llvmlistbot at llvm.org
Thu May 28 01:32:05 PDT 2026


https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/199051

>From 11177054771b4c8e69b2f7783751ee6e7da4cb6d Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Thu, 21 May 2026 10:26:40 +0100
Subject: [PATCH] [mlir][tosa] Limit consecutive concat rewrite to
 MAX_TENSOR_LIST_SIZE

Previously, folding could produce an operation that would
later be considered invalid because of the number of operands
it has. This change adds a check that prevents rewriting
consecutive concat operations if the resulting operation
would have more than MAX_TENSOR_LIST_SIZE operands, based on
the level of the selected target environment. If no target
environment is specified, folding proceeds as before.

In addition, this change rewrites the concat folder as a
canonicalization pattern, since it is not a fold of constant
operands. The change also consolidates the concat folding
tests into canonicalize.mlir.

Change-Id: I41d3b1672f04f41a095674b99e13e4a7efd4fcdb
---
 mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h | 16 ++--
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |  1 -
 mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp        | 10 ++
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 89 +++++++++++-------
 mlir/test/Dialect/Tosa/canonicalize.mlir      | 88 ++++++++++++++++++
 mlir/test/Dialect/Tosa/fold_concats.mlir      | 93 -------------------
 6 files changed, 158 insertions(+), 139 deletions(-)
 delete mode 100644 mlir/test/Dialect/Tosa/fold_concats.mlir

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
index 3189488cd6c6b..704c02596d6e9 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
@@ -45,6 +45,8 @@ static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6,  8192, 8192, 256,
 static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
                                               63, 256,        256,        64};
 
+TosaLevel getTosaLevelFromEnum(const Level level);
+
 TargetEnvAttr lookupTargetEnv(Operation *op);
 TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
 
@@ -131,14 +133,7 @@ class TargetEnv {
     return specificationVersion;
   }
 
-  TosaLevel getLevel() const {
-    if (level == Level::eightK)
-      return TOSA_LEVEL_EIGHTK;
-    else if (level == Level::none)
-      return TOSA_LEVEL_NONE;
-    else
-      llvm_unreachable("Unknown TOSA level");
-  };
+  TosaLevel getLevel() const { return level; }
 
   // Returns true if the given profile is allowed.
   bool allows(Profile prof) const { return enabledProfiles.count(prof) != 0; }
@@ -168,13 +163,14 @@ class TargetEnv {
   explicit TargetEnv(SpecificationVersion specificationVersion, Level level,
                      const ArrayRef<Profile> &profiles,
                      const ArrayRef<Extension> &extensions)
-      : specificationVersion(specificationVersion), level(level) {
+      : specificationVersion(specificationVersion),
+        level(getTosaLevelFromEnum(level)) {
     enabledProfiles.insert_range(profiles);
     enabledExtensions.insert_range(extensions);
   }
 
   TosaSpecificationVersion specificationVersion;
-  Level level;
+  TosaLevel level;
   llvm::SmallSet<Profile, 3> enabledProfiles;
   llvm::SmallSet<Extension, 13> enabledExtensions;
 };
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e135265b99881..a99fb2fcae547 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2206,7 +2206,6 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat", [Pure]> {
   ];
 
   let hasCanonicalizer = 1;
-  let hasFolder = 1;
   let hasVerifier = 1;
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
index 35122afa430c3..dc18fcaa04c8a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -12,6 +12,16 @@
 namespace mlir {
 namespace tosa {
 
+TosaLevel getTosaLevelFromEnum(const Level level) {
+  switch (level) { // No default: -Wswitch then flags any unhandled Level.
+  case Level::eightK:
+    return TOSA_LEVEL_EIGHTK;
+  case Level::none:
+    return TOSA_LEVEL_NONE;
+  }
+  llvm_unreachable("Unknown TOSA level"); // Reached only if no case matched.
+}
+
 llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
   return llvm::formatv("{0}.{1}{2}", version.getMajor(), version.getMinor(),
                        version.isDraft() ? ".draft" : "");
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 4af185a6e534b..a53ea689c2b72 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Dialect/Quant/IR/Quant.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
@@ -367,9 +368,61 @@ struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
   }
 };
 
+struct ConsecutiveConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
+  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ConcatOp op,
+                                PatternRewriter &rewriter) const override {
+    // Flatten same-axis producer concats into this op. The reserve below is
+    // only a size hint (it assumes the operand count roughly doubles).
+    // With a target env attached, the flattened operand count is capped at
+    // its level's MAX_TENSOR_LIST_SIZE; otherwise (limit 0) it is unbounded.
+    SmallVector<Value, 8> concatOperands;
+    concatOperands.reserve(2 * op.getNumOperands());
+
+    int32_t maxNumOperands = 0;
+    if (auto targetEnvAttr = tosa::lookupTargetEnv(op))
+      maxNumOperands =
+          getTosaLevelFromEnum(targetEnvAttr.getLevel()).MAX_TENSOR_LIST_SIZE;
+
+    // Copy operands in order, expanding same-axis producer concats in place.
+    bool foundRewritableConcat = false;
+    for (Value operand : op.getOperands()) {
+      concatOperands.emplace_back(operand);
+
+      auto producer = operand.getDefiningOp<tosa::ConcatOp>();
+      if (!producer)
+        continue;
+
+      // Producers concatenating along a different axis are kept as-is.
+      if (op.getAxis() != producer.getAxis())
+        continue;
+
+      // Drop the producer's result and splice in its operands instead.
+      foundRewritableConcat = true;
+      concatOperands.pop_back();
+      llvm::append_range(concatOperands, producer->getOperands());
+    }
+
+    if (!foundRewritableConcat)
+      return rewriter.notifyMatchFailure(op,
+                                         "No rewritable concat operand found.");
+
+    if (maxNumOperands > 0 &&
+        concatOperands.size() > static_cast<size_t>(maxNumOperands))
+      return rewriter.notifyMatchFailure(
+          op, "Rewriting would exceed the maximum number of operands for the "
+              "target environment level.");
+
+    rewriter.replaceOpWithNewOp<tosa::ConcatOp>(
+        op, op.getType(), concatOperands, op.getAxisAttr());
+    return success();
+  }
+};
+
 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<ConcatOptimization>(context);
+  results.add<ConcatOptimization, ConsecutiveConcatOptimization>(context);
 }
 
 LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
@@ -2223,40 +2276,6 @@ OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
-OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
-  // Fold consecutive concats on the same axis into a single op.
-  // Keep track of the operands so we are able to construct a new concat
-  // later. Conservatively assume that we double the number of operands when
-  // folding
-  SmallVector<Value, 8> concatOperands;
-  concatOperands.reserve(2 * getNumOperands());
-
-  // Find all operands that are foldable concats
-  bool foundFoldableConcat = false;
-  for (Value operand : getOperands()) {
-    concatOperands.emplace_back(operand);
-
-    auto producer = operand.getDefiningOp<ConcatOp>();
-    if (!producer)
-      continue;
-
-    // Not foldable if axes are not the same
-    if (getAxis() != producer.getAxis())
-      continue;
-
-    // Replace the original operand with all incoming operands
-    foundFoldableConcat = true;
-    concatOperands.pop_back();
-    llvm::append_range(concatOperands, producer->getOperands());
-  }
-
-  if (!foundFoldableConcat)
-    return {};
-
-  getOperation()->setOperands(concatOperands);
-  return getResult();
-}
-
 OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
   auto input = adaptor.getInput1();
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index d4f3d23fd761e..2cd040f056db8 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1762,3 +1762,91 @@ func.func @dont_canonicalize_non_const_avg_pool2d_adaptive(%arg0: tensor<1x?x?x8
           (tensor<1x?x?x8xf32>, tensor<1xf32>, tensor<1xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x?x?x8xf32>
   return %0 : tensor<1x?x?x8xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_single_concat(
+// CHECK: %[[VAL_1:.*]] = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+// CHECK: return %[[VAL_1]] : tensor<1x2x7x7xf32>
+func.func @test_single_concat(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> {
+  %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  return %0 : tensor<1x2x7x7xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_concat_different_axis(
+// CHECK: %[[VAL_1:.*]] = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_1]] {axis = 0 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32>
+// CHECK: return %[[VAL_2]] : tensor<2x2x7x7xf32>
+func.func @test_concat_different_axis(%arg0: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> {
+  %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  %1 = tosa.concat %0, %0 {axis = 0 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32>
+  return %1 : tensor<2x2x7x7xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_fold_concats(
+// CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<1x1x7x7xf32>
+// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %arg0, %arg0, %[[VAL_1]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
+// CHECK: return %[[VAL_2]] : tensor<1x4x7x7xf32>
+func.func @test_fold_concats(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
+  %tmp = tensor.empty() : tensor<1x1x7x7xf32>
+  %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  %1 = tosa.concat %tmp, %0, %tmp {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
+  return %1 : tensor<1x4x7x7xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_nested_fold(
+// CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<1x1x7x7xf32>
+// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %arg0, %arg0, %[[VAL_1]], %[[VAL_1]], %arg0, %arg0, %[[VAL_1]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32>
+// CHECK: return %[[VAL_2]] : tensor<1x8x7x7xf32>
+func.func @test_nested_fold(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32> {
+  %tmp = tensor.empty() : tensor<1x1x7x7xf32>
+  %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  %1 = tosa.concat %tmp, %0, %tmp {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
+  %2 = tosa.concat %1, %1 {axis = 1 : i32} : (tensor<1x4x7x7xf32>, tensor<1x4x7x7xf32>) -> tensor<1x8x7x7xf32>
+  return %2 : tensor<1x8x7x7xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_nested_fold_too_many_operands(
+// CHECK: %[[VAL_1:.*]] = tosa.concat %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32>
+// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %arg0 {axis = 1 : i32} : (tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x65x7x7xf32>
+// CHECK: return %[[VAL_2]] : tensor<1x65x7x7xf32>
+module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = "8k", profiles = [pro_fp], extensions = [int16]>} {
+  func.func @test_nested_fold_too_many_operands(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x65x7x7xf32> {
+    %0 = tosa.concat %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32>
+    %1 = tosa.concat %0, %0, %0, %0, %0, %0, %0, %0, %arg0 {axis = 1 : i32} : (tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x65x7x7xf32>
+    return %1 : tensor<1x65x7x7xf32>
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_wide_fold(
+// CHECK: %[[VAL_2:.*]] = tosa.concat %arg0, %arg0, %arg1, %arg1 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
+// CHECK: return %[[VAL_2]] : tensor<1x4x7x7xf32>
+func.func @test_wide_fold(%arg0: tensor<1x1x7x7xf32>, %arg1: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
+  %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  %1 = tosa.concat %arg1, %arg1 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  %2 = tosa.concat %0, %1 {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32>
+  return %2 : tensor<1x4x7x7xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_partially_foldable(
+// CHECK: %[[VAL_2:.*]] = tosa.concat %arg1, %arg1 {axis = 2 : i32} : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32>
+// CHECK: %[[VAL_3:.*]] = tosa.concat %arg0, %arg0, %[[VAL_2]] {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
+// CHECK: return %[[VAL_3]] : tensor<1x4x8x8xf32>
+func.func @test_partially_foldable(%arg0: tensor<1x1x8x8xf32>, %arg1: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> {
+  %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>) -> tensor<1x2x8x8xf32>
+  %1 = tosa.concat %arg1, %arg1 {axis = 2 : i32} : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32>
+  %2 = tosa.concat %0, %1 {axis = 1 : i32} : (tensor<1x2x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
+  return %2 : tensor<1x4x8x8xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/fold_concats.mlir b/mlir/test/Dialect/Tosa/fold_concats.mlir
deleted file mode 100644
index ec54f27346c8b..0000000000000
--- a/mlir/test/Dialect/Tosa/fold_concats.mlir
+++ /dev/null
@@ -1,93 +0,0 @@
-// RUN: mlir-opt --split-input-file --canonicalize %s | FileCheck %s
-
-func.func @single_concat(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> {
-  %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
-  return %0 : tensor<1x2x7x7xf32>
-}
-
-// CHECK-LABEL:   func.func @single_concat(
-// CHECK-SAME:                             %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> {
-// CHECK:           %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
-// CHECK:           return %[[VAL_1]] : tensor<1x2x7x7xf32>
-// CHECK:         }
-
-// -----
-
-func.func @concat_different_axis(%arg0: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> {
-  %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
-  %1 = tosa.concat %0, %0 {axis = 0 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32>
-  return %1 : tensor<2x2x7x7xf32>
-}
-
-// CHECK-LABEL:   func.func @concat_different_axis(
-// CHECK-SAME:                                     %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> {
-// CHECK:           %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
-// CHECK:           %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_1]] {axis = 0 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32>
-// CHECK:           return %[[VAL_2]] : tensor<2x2x7x7xf32>
-// CHECK:         }
-
-// -----
-
-func.func @fold_concats(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
-  %tmp = tensor.empty() : tensor<1x1x7x7xf32>
-  %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
-  %1 = tosa.concat %tmp, %0, %tmp {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
-  return %1 : tensor<1x4x7x7xf32>
-}
-
-// CHECK-LABEL:   func.func @fold_concats(
-// CHECK-SAME:                            %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
-// CHECK:           %[[VAL_1:.*]] = tensor.empty() : tensor<1x1x7x7xf32>
-// CHECK:           %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
-// CHECK:           return %[[VAL_2]] : tensor<1x4x7x7xf32>
-// CHECK:         }
-
-// -----
-
-func.func @nested_fold(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32> {
-  %tmp = tensor.empty() : tensor<1x1x7x7xf32>
-  %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
-  %1 = tosa.concat %tmp, %0, %tmp {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
-  %2 = tosa.concat %1, %1 {axis = 1 : i32} : (tensor<1x4x7x7xf32>, tensor<1x4x7x7xf32>) -> tensor<1x8x7x7xf32>
-  return %2 : tensor<1x8x7x7xf32>
-}
-
-// CHECK-LABEL:   func.func @nested_fold(
-// CHECK-SAME:                           %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32> {
-// CHECK:           %[[VAL_1:.*]] = tensor.empty() : tensor<1x1x7x7xf32>
-// CHECK:           %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]], %[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32>
-// CHECK:           return %[[VAL_2]] : tensor<1x8x7x7xf32>
-// CHECK:         }
-
-// -----
-
-func.func @wide_fold(%arg0: tensor<1x1x7x7xf32>, %arg1: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
-  %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
-  %1 = tosa.concat %arg1, %arg1 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
-  %2 = tosa.concat %0, %1 {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32>
-  return %2 : tensor<1x4x7x7xf32>
-}
-
-// CHECK-LABEL:   func.func @wide_fold(
-// CHECK-SAME:                         %[[VAL_0:.*]]: tensor<1x1x7x7xf32>,
-// CHECK-SAME:                         %[[VAL_1:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
-// CHECK:           %[[VAL_2:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]], %[[VAL_1]], %[[VAL_1]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
-// CHECK:           return %[[VAL_2]] : tensor<1x4x7x7xf32>
-// CHECK:         }
-
-// -----
-
-func.func @partially_foldable(%arg0: tensor<1x1x8x8xf32>, %arg1: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> {
-  %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>) -> tensor<1x2x8x8xf32>
-  %1 = tosa.concat %arg1, %arg1 {axis = 2 : i32} : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32>
-  %2 = tosa.concat %0, %1 {axis = 1 : i32} : (tensor<1x2x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
-  return %2 : tensor<1x4x8x8xf32>
-}
-
-// CHECK-LABEL:   func.func @partially_foldable(
-// CHECK-SAME:                                  %[[VAL_0:.*]]: tensor<1x1x8x8xf32>,
-// CHECK-SAME:                                  %[[VAL_1:.*]]: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> {
-// CHECK:           %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_1]] {axis = 2 : i32} : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32>
-// CHECK:           %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]], %[[VAL_2]] {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
-// CHECK:           return %[[VAL_3]] : tensor<1x4x8x8xf32>
-// CHECK:         }



More information about the Mlir-commits mailing list