[Mlir-commits] [mlir] [mlir][tosa] Limit consecutive concat rewrite to MAX_TENSOR_LIST_SIZE (PR #199051)
Luke Hutton
llvmlistbot at llvm.org
Thu May 28 01:32:05 PDT 2026
https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/199051
>From 11177054771b4c8e69b2f7783751ee6e7da4cb6d Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Thu, 21 May 2026 10:26:40 +0100
Subject: [PATCH] [mlir][tosa] Limit consecutive concat rewrite to
MAX_TENSOR_LIST_SIZE
Previously folding could produce an operation that would
later be considered invalid due to the number of operands
it has. This change adds a check to prevent rewriting
consecutive concat operations if the resulting operation
has more than MAX_TENSOR_LIST_SIZE operands, based on the
selected target environment level. If no level is specified,
folding will proceed as before.
In addition, this change rewrites the concat folder as a
canonicalization pattern, since it is not a fold of constant
operands. The change also consolidates testing in
canonicalize.mlir.
Change-Id: I41d3b1672f04f41a095674b99e13e4a7efd4fcdb
---
mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h | 16 ++--
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 1 -
mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp | 10 ++
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 89 +++++++++++-------
mlir/test/Dialect/Tosa/canonicalize.mlir | 88 ++++++++++++++++++
mlir/test/Dialect/Tosa/fold_concats.mlir | 93 -------------------
6 files changed, 158 insertions(+), 139 deletions(-)
delete mode 100644 mlir/test/Dialect/Tosa/fold_concats.mlir
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
index 3189488cd6c6b..704c02596d6e9 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
@@ -45,6 +45,8 @@ static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256,
static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
63, 256, 256, 64};
+TosaLevel getTosaLevelFromEnum(const Level level);
+
TargetEnvAttr lookupTargetEnv(Operation *op);
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
@@ -131,14 +133,7 @@ class TargetEnv {
return specificationVersion;
}
- TosaLevel getLevel() const {
- if (level == Level::eightK)
- return TOSA_LEVEL_EIGHTK;
- else if (level == Level::none)
- return TOSA_LEVEL_NONE;
- else
- llvm_unreachable("Unknown TOSA level");
- };
+ TosaLevel getLevel() const { return level; };
// Returns true if the given profile is allowed.
bool allows(Profile prof) const { return enabledProfiles.count(prof) != 0; }
@@ -168,13 +163,14 @@ class TargetEnv {
explicit TargetEnv(SpecificationVersion specificationVersion, Level level,
const ArrayRef<Profile> &profiles,
const ArrayRef<Extension> &extensions)
- : specificationVersion(specificationVersion), level(level) {
+ : specificationVersion(specificationVersion),
+ level(getTosaLevelFromEnum(level)) {
enabledProfiles.insert_range(profiles);
enabledExtensions.insert_range(extensions);
}
TosaSpecificationVersion specificationVersion;
- Level level;
+ TosaLevel level;
llvm::SmallSet<Profile, 3> enabledProfiles;
llvm::SmallSet<Extension, 13> enabledExtensions;
};
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e135265b99881..a99fb2fcae547 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2206,7 +2206,6 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat", [Pure]> {
];
let hasCanonicalizer = 1;
- let hasFolder = 1;
let hasVerifier = 1;
let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
index 35122afa430c3..dc18fcaa04c8a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -12,6 +12,16 @@
namespace mlir {
namespace tosa {
+TosaLevel getTosaLevelFromEnum(const Level level) { // Maps a Level enumerator to its TosaLevel constant.
+ switch (level) {
+ case Level::eightK:
+ return TOSA_LEVEL_EIGHTK;
+ case Level::none:
+ return TOSA_LEVEL_NONE;
+ }
+ llvm_unreachable("Unknown TOSA level"); // Only reached if `level` matches neither case above.
+}
+
llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
return llvm::formatv("{0}.{1}{2}", version.getMajor(), version.getMinor(),
version.isDraft() ? ".draft" : "");
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 4af185a6e534b..a53ea689c2b72 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
@@ -367,9 +368,61 @@ struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
}
};
+struct ConsecutiveConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
+ using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::ConcatOp op,
+ PatternRewriter &rewriter) const override {
+ // Merge same-axis concat producers into this op so a chain of concats
+ // becomes one. The merge is skipped if the result would have more operands
+ // than MAX_TENSOR_LIST_SIZE of the level found via lookupTargetEnv; with no
+ // target env (limit stays 0) it is unbounded. The reserve is only a guess.
+ SmallVector<Value, 8> concatOperands;
+ concatOperands.reserve(2 * op.getNumOperands());
+
+ int32_t maxNumOperands = 0;
+ if (auto targetEnvAttr = tosa::lookupTargetEnv(op))
+ maxNumOperands =
+ getTosaLevelFromEnum(targetEnvAttr.getLevel()).MAX_TENSOR_LIST_SIZE;
+
+ // Expand same-axis concat producers in place; operand order is preserved.
+ bool foundRewritableConcat = false;
+ for (Value operand : op.getOperands()) {
+ concatOperands.emplace_back(operand);
+
+ auto producer = operand.getDefiningOp<tosa::ConcatOp>();
+ if (!producer)
+ continue;
+
+ // A producer concatenating on a different axis stays a plain operand.
+ if (op.getAxis() != producer.getAxis())
+ continue;
+
+ // Swap the operand just pushed for the producer's own operands.
+ foundRewritableConcat = true;
+ concatOperands.pop_back();
+ llvm::append_range(concatOperands, producer->getOperands());
+ }
+
+ if (!foundRewritableConcat)
+ return rewriter.notifyMatchFailure(op,
+ "No rewritable concat operand found.");
+
+ if (maxNumOperands > 0 &&
+ concatOperands.size() > static_cast<size_t>(maxNumOperands))
+ return rewriter.notifyMatchFailure(
+ op, "Rewriting would exceed the maximum number of operands for the "
+ "target environment level.");
+
+ rewriter.replaceOpWithNewOp<tosa::ConcatOp>(
+ op, op.getType(), concatOperands, op.getAxisAttr());
+ return success();
+ }
+};
+
void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ConcatOptimization>(context);
+ results.add<ConcatOptimization, ConsecutiveConcatOptimization>(context);
}
LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
@@ -2223,40 +2276,6 @@ OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
return {};
}
-OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
- // Fold consecutive concats on the same axis into a single op.
- // Keep track of the operands so we are able to construct a new concat
- // later. Conservatively assume that we double the number of operands when
- // folding
- SmallVector<Value, 8> concatOperands;
- concatOperands.reserve(2 * getNumOperands());
-
- // Find all operands that are foldable concats
- bool foundFoldableConcat = false;
- for (Value operand : getOperands()) {
- concatOperands.emplace_back(operand);
-
- auto producer = operand.getDefiningOp<ConcatOp>();
- if (!producer)
- continue;
-
- // Not foldable if axes are not the same
- if (getAxis() != producer.getAxis())
- continue;
-
- // Replace the original operand with all incoming operands
- foundFoldableConcat = true;
- concatOperands.pop_back();
- llvm::append_range(concatOperands, producer->getOperands());
- }
-
- if (!foundFoldableConcat)
- return {};
-
- getOperation()->setOperands(concatOperands);
- return getResult();
-}
-
OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
auto input = adaptor.getInput1();
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index d4f3d23fd761e..2cd040f056db8 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1762,3 +1762,91 @@ func.func @dont_canonicalize_non_const_avg_pool2d_adaptive(%arg0: tensor<1x?x?x8
(tensor<1x?x?x8xf32>, tensor<1xf32>, tensor<1xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x?x?x8xf32>
return %0 : tensor<1x?x?x8xf32>
}
+
+// -----
+
+// CHECK-LABEL: test_single_concat
+// CHECK: %[[VAL_1:.*]] = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+// CHECK: return %[[VAL_1]] : tensor<1x2x7x7xf32>
+func.func @test_single_concat(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> {
+ %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+ return %0 : tensor<1x2x7x7xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_concat_different_axis
+// CHECK: %[[VAL_1:.*]] = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_1]] {axis = 0 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32>
+// CHECK: return %[[VAL_2]] : tensor<2x2x7x7xf32>
+func.func @test_concat_different_axis(%arg0: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> {
+ %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+ %1 = tosa.concat %0, %0 {axis = 0 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32>
+ return %1 : tensor<2x2x7x7xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_fold_concats
+// CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<1x1x7x7xf32>
+// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %arg0, %arg0, %[[VAL_1]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
+// CHECK: return %[[VAL_2]] : tensor<1x4x7x7xf32>
+func.func @test_fold_concats(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
+ %tmp = tensor.empty() : tensor<1x1x7x7xf32>
+ %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+ %1 = tosa.concat %tmp, %0, %tmp {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
+ return %1 : tensor<1x4x7x7xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_nested_fold
+// CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<1x1x7x7xf32>
+// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %arg0, %arg0, %[[VAL_1]], %[[VAL_1]], %arg0, %arg0, %[[VAL_1]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32>
+// CHECK: return %[[VAL_2]] : tensor<1x8x7x7xf32>
+func.func @test_nested_fold(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32> {
+ %tmp = tensor.empty() : tensor<1x1x7x7xf32>
+ %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+ %1 = tosa.concat %tmp, %0, %tmp {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
+ %2 = tosa.concat %1, %1 {axis = 1 : i32} : (tensor<1x4x7x7xf32>, tensor<1x4x7x7xf32>) -> tensor<1x8x7x7xf32>
+ return %2 : tensor<1x8x7x7xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_nested_fold_too_many_operands
+// CHECK: %[[VAL_1:.*]] = tosa.concat %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32>
+// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %arg0 {axis = 1 : i32} : (tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x65x7x7xf32>
+// CHECK: return %[[VAL_2]] : tensor<1x65x7x7xf32>
+module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = "8k", profiles = [pro_fp], extensions = [int16]>} {
+ func.func @test_nested_fold_too_many_operands(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x65x7x7xf32> {
+ %0 = tosa.concat %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32>
+ %1 = tosa.concat %0, %0, %0, %0, %0, %0, %0, %0, %arg0 {axis = 1 : i32} : (tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x8x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x65x7x7xf32>
+ return %1 : tensor<1x65x7x7xf32>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: test_wide_fold
+// CHECK: %[[VAL_2:.*]] = tosa.concat %arg0, %arg0, %arg1, %arg1 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
+// CHECK: return %[[VAL_2]] : tensor<1x4x7x7xf32>
+func.func @test_wide_fold(%arg0: tensor<1x1x7x7xf32>, %arg1: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
+ %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+ %1 = tosa.concat %arg1, %arg1 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+ %2 = tosa.concat %0, %1 {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32>
+ return %2 : tensor<1x4x7x7xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_partially_foldable
+// CHECK: %[[VAL_2:.*]] = tosa.concat %arg1, %arg1 {axis = 2 : i32} : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32>
+// CHECK: %[[VAL_3:.*]] = tosa.concat %arg0, %arg0, %[[VAL_2]] {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
+// CHECK: return %[[VAL_3]] : tensor<1x4x8x8xf32>
+func.func @test_partially_foldable(%arg0: tensor<1x1x8x8xf32>, %arg1: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> {
+ %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>) -> tensor<1x2x8x8xf32>
+ %1 = tosa.concat %arg1, %arg1 {axis = 2 : i32} : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32>
+ %2 = tosa.concat %0, %1 {axis = 1 : i32} : (tensor<1x2x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
+ return %2 : tensor<1x4x8x8xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/fold_concats.mlir b/mlir/test/Dialect/Tosa/fold_concats.mlir
deleted file mode 100644
index ec54f27346c8b..0000000000000
--- a/mlir/test/Dialect/Tosa/fold_concats.mlir
+++ /dev/null
@@ -1,93 +0,0 @@
-// RUN: mlir-opt --split-input-file --canonicalize %s | FileCheck %s
-
-func.func @single_concat(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> {
- %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
- return %0 : tensor<1x2x7x7xf32>
-}
-
-// CHECK-LABEL: func.func @single_concat(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> {
-// CHECK: %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
-// CHECK: return %[[VAL_1]] : tensor<1x2x7x7xf32>
-// CHECK: }
-
-// -----
-
-func.func @concat_different_axis(%arg0: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> {
- %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
- %1 = tosa.concat %0, %0 {axis = 0 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32>
- return %1 : tensor<2x2x7x7xf32>
-}
-
-// CHECK-LABEL: func.func @concat_different_axis(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> {
-// CHECK: %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
-// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_1]] {axis = 0 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32>
-// CHECK: return %[[VAL_2]] : tensor<2x2x7x7xf32>
-// CHECK: }
-
-// -----
-
-func.func @fold_concats(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
- %tmp = tensor.empty() : tensor<1x1x7x7xf32>
- %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
- %1 = tosa.concat %tmp, %0, %tmp {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
- return %1 : tensor<1x4x7x7xf32>
-}
-
-// CHECK-LABEL: func.func @fold_concats(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
-// CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<1x1x7x7xf32>
-// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
-// CHECK: return %[[VAL_2]] : tensor<1x4x7x7xf32>
-// CHECK: }
-
-// -----
-
-func.func @nested_fold(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32> {
- %tmp = tensor.empty() : tensor<1x1x7x7xf32>
- %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
- %1 = tosa.concat %tmp, %0, %tmp {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
- %2 = tosa.concat %1, %1 {axis = 1 : i32} : (tensor<1x4x7x7xf32>, tensor<1x4x7x7xf32>) -> tensor<1x8x7x7xf32>
- return %2 : tensor<1x8x7x7xf32>
-}
-
-// CHECK-LABEL: func.func @nested_fold(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32> {
-// CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<1x1x7x7xf32>
-// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]], %[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32>
-// CHECK: return %[[VAL_2]] : tensor<1x8x7x7xf32>
-// CHECK: }
-
-// -----
-
-func.func @wide_fold(%arg0: tensor<1x1x7x7xf32>, %arg1: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
- %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
- %1 = tosa.concat %arg1, %arg1 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
- %2 = tosa.concat %0, %1 {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32>
- return %2 : tensor<1x4x7x7xf32>
-}
-
-// CHECK-LABEL: func.func @wide_fold(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
-// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]], %[[VAL_1]], %[[VAL_1]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
-// CHECK: return %[[VAL_2]] : tensor<1x4x7x7xf32>
-// CHECK: }
-
-// -----
-
-func.func @partially_foldable(%arg0: tensor<1x1x8x8xf32>, %arg1: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> {
- %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>) -> tensor<1x2x8x8xf32>
- %1 = tosa.concat %arg1, %arg1 {axis = 2 : i32} : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32>
- %2 = tosa.concat %0, %1 {axis = 1 : i32} : (tensor<1x2x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
- return %2 : tensor<1x4x8x8xf32>
-}
-
-// CHECK-LABEL: func.func @partially_foldable(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x8x8xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> {
-// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_1]] {axis = 2 : i32} : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32>
-// CHECK: %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]], %[[VAL_2]] {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
-// CHECK: return %[[VAL_3]] : tensor<1x4x8x8xf32>
-// CHECK: }
More information about the Mlir-commits
mailing list