[Mlir-commits] [mlir] [mlir][tosa] Add concat_shape Op folder (PR #183293)
Udaya Ranga
llvmlistbot at llvm.org
Thu Mar 12 10:52:33 PDT 2026
https://github.com/udaya-ranga updated https://github.com/llvm/llvm-project/pull/183293
>From 785835f25bb1018a912bedd20ed4de313ada3f21 Mon Sep 17 00:00:00 2001
From: Udaya Ranga <udaya.ranga at arm.com>
Date: Tue, 27 Jan 2026 15:08:25 +0000
Subject: [PATCH 1/3] CONCAT_SHAPE Op folder
Signed-off-by: Udaya Ranga <udaya.ranga at arm.com>
Change-Id: Icdedada797ec7be9825e31d44a11b80832891ea1
---
.../mlir/Dialect/Tosa/IR/TosaShapeOps.td | 2 +
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 38 ++++++++
mlir/test/Dialect/Tosa/constant_folding.mlir | 86 +++++++++++++++++++
3 files changed, 126 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index 3fc9ca5d810fe..f2c48b7684c26 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -89,6 +89,8 @@ def Tosa_ConcatShapeOp : Tosa_ShapeOp<"concat_shape", [Pure]> {
let results = (outs Tosa_Shape:$output);
let hasVerifier = 1;
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 7a1dbcd3e84c7..b8c7dab7ddd68 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -2020,6 +2020,40 @@ OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
return DenseElementsAttr::get(resultAttrTy, dimSize);
}
+/// Folds tosa.concat_shape whose inputs are all tosa.const_shape ops into a
+/// single rank-1 index DenseElementsAttr holding their values in operand
+/// order. Returns a null OpFoldResult (no fold) in every other case.
+static OpFoldResult concatShapeFold(tosa::ConcatShapeOp *op) {
+  const auto inputs = op->getInput();
+  // No operands: nothing to fold.
+  if (inputs.empty())
+    return {};
+
+  SmallVector<APInt> concatDims;
+  for (Value input : inputs) {
+    // getDefiningOp<T>() is null-safe: a block-argument input has no defining
+    // op, and dyn_cast on a null Operation* would assert.
+    auto constShape = input.getDefiningOp<tosa::ConstShapeOp>();
+    if (!constShape)
+      return {};
+
+    // dyn_cast (not cast) so an unexpected attribute declines the fold instead
+    // of asserting. Requiring index elements keeps every APInt at the width
+    // DenseElementsAttr::get expects for the index-typed result built below.
+    auto valuesAttr = dyn_cast<DenseElementsAttr>(constShape.getValues());
+    if (!valuesAttr || !valuesAttr.getElementType().isIndex())
+      return {};
+
+    llvm::append_range(concatDims, valuesAttr.getValues<APInt>());
+  }
+
+  // Result: rank-1 tensor<Nxindex>, N = total number of concatenated dims.
+  auto resultTy = RankedTensorType::get(
+      {static_cast<int64_t>(concatDims.size())},
+      IndexType::get(op->getContext()));
+  return DenseElementsAttr::get(resultTy, concatDims);
+}
+
OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
return binaryFold<AddShapeOp, AddFoldAdaptor>(this);
}
@@ -2063,3 +2097,7 @@ OpFoldResult tosa::Log2CeilShapeOp::fold(FoldAdaptor adaptor) {
OpFoldResult tosa::Log2FloorShapeOp::fold(FoldAdaptor adaptor) {
return unaryShapeFold<Log2FloorShapeOp, Log2FloorFoldAdaptor>(this);
}
+
+OpFoldResult tosa::ConcatShapeOp::fold(FoldAdaptor /*adaptor*/) {
+  return concatShapeFold(this); // Folds only if all inputs are const_shape.
+}
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 47b9e36b08558..c8276a56b8ddf 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -1371,3 +1371,89 @@ func.func @test_log2_floor_shape_neg() -> !tosa.shape<6> {
%c = tosa.log2_floor_shape %a : (!tosa.shape<6>) -> !tosa.shape<6>
return %c : !tosa.shape<6>
}
+
+// -----
+
+// CHECK-LABEL: @test_concat_shape
+// CHECK: tosa.const_shape {values = dense<[4, 9, 5, 19]> : tensor<4xindex>} : () -> !tosa.shape<4>
+func.func @test_concat_shape() -> !tosa.shape<4> {
+  %lhs = tosa.const_shape {values = dense<[4, 9]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %rhs = tosa.const_shape {values = dense<[5, 19]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %concat = tosa.concat_shape %lhs, %rhs : (!tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<4>
+  return %concat : !tosa.shape<4>
+}
+
+// -----
+
+// CHECK-LABEL: @test_concat_shape_rank6
+// CHECK: tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]> : tensor<12xindex>} : () -> !tosa.shape<12>
+func.func @test_concat_shape_rank6() -> !tosa.shape<12> {
+  %head = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %tail = tosa.const_shape {values = dense<[7, 8, 9, 10, 11, 12]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %joined = tosa.concat_shape %head, %tail : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<12>
+  return %joined : !tosa.shape<12>
+}
+
+// -----
+
+// CHECK-LABEL: @test_concat_shape_total_rank12_5inputs
+// CHECK-DAG: %[[C:.*]] = tosa.const_shape {values = dense<[6, 7, 8, 9]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[D:.*]] = tosa.const_shape {values = dense<10> : tensor<1xindex>} : () -> !tosa.shape<1>
+// CHECK-DAG: %[[E:.*]] = tosa.const_shape {values = dense<[11, 12]> : tensor<2xindex>} : () -> !tosa.shape<2>
+// CHECK-DAG: %[[AB:.*]] = tosa.const_shape {values = dense<[1, 2, 3, 4, 5]> : tensor<5xindex>} : () -> !tosa.shape<5>
+// CHECK: %[[ABC:.*]] = tosa.concat_shape %[[AB]], %[[C]] : (!tosa.shape<5>, !tosa.shape<4>) -> !tosa.shape<9>
+// CHECK: %[[ABCD:.*]] = tosa.concat_shape %[[ABC]], %[[D]] : (!tosa.shape<9>, !tosa.shape<1>) -> !tosa.shape<10>
+// CHECK: %[[ABCDE:.*]] = tosa.concat_shape %[[ABCD]], %[[E]] : (!tosa.shape<10>, !tosa.shape<2>) -> !tosa.shape<12>
+// CHECK: return %[[ABCDE]] : !tosa.shape<12>
+func.func @test_concat_shape_total_rank12_5inputs() -> !tosa.shape<12> {
+  // Ranks: 3 + 2 + 4 + 1 + 2 = 12
+  %x0 = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %x1 = tosa.const_shape {values = dense<[4, 5]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %x2 = tosa.const_shape {values = dense<[6, 7, 8, 9]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %x3 = tosa.const_shape {values = dense<[10]> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %x4 = tosa.const_shape {values = dense<[11, 12]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // Left-to-right chain; per the CHECKs only the innermost pair folds.
+  %acc1 = tosa.concat_shape %x0, %x1 : (!tosa.shape<3>, !tosa.shape<2>) -> !tosa.shape<5>
+  %acc2 = tosa.concat_shape %acc1, %x2 : (!tosa.shape<5>, !tosa.shape<4>) -> !tosa.shape<9>
+  %acc3 = tosa.concat_shape %acc2, %x3 : (!tosa.shape<9>, !tosa.shape<1>) -> !tosa.shape<10>
+  %acc4 = tosa.concat_shape %acc3, %x4 : (!tosa.shape<10>, !tosa.shape<2>) -> !tosa.shape<12>
+
+  return %acc4 : !tosa.shape<12>
+}
+
+// -----
+
+// CHECK-LABEL: @test_concat_shape_rank6_4inputs
+// CHECK: tosa.concat_shape
+func.func @test_concat_shape_rank6_4inputs() -> !tosa.shape<24> {
+  %x0 = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %x1 = tosa.const_shape {values = dense<[7, 8, 9, 10, 11, 12]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %x2 = tosa.const_shape {values = dense<[13, 14, 15, 16, 17, 18]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %x3 = tosa.const_shape {values = dense<[19, 20, 21, 22, 23, 24]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %lo = tosa.concat_shape %x0, %x1 : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<12>
+  %hi = tosa.concat_shape %x2, %x3 : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<12>
+  %all = tosa.concat_shape %lo, %hi : (!tosa.shape<12>, !tosa.shape<12>) -> !tosa.shape<24>
+  return %all : !tosa.shape<24>
+}
+
+// -----
+
+// CHECK-LABEL: @test_concat_shape_total_rank13_5inputs
+// CHECK: tosa.concat_shape
+func.func @test_concat_shape_total_rank13_5inputs() -> !tosa.shape<13> {
+  // Ranks: 3 + 2 + 4 + 1 + 3 = 13
+  %x0 = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %x1 = tosa.const_shape {values = dense<[4, 5]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %x2 = tosa.const_shape {values = dense<[6, 7, 8, 9]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %x3 = tosa.const_shape {values = dense<[10]> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %x4 = tosa.const_shape {values = dense<[11, 12, 13]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  // Left-to-right chain of binary concat_shape ops.
+  %acc1 = tosa.concat_shape %x0, %x1 : (!tosa.shape<3>, !tosa.shape<2>) -> !tosa.shape<5>
+  %acc2 = tosa.concat_shape %acc1, %x2 : (!tosa.shape<5>, !tosa.shape<4>) -> !tosa.shape<9>
+  %acc3 = tosa.concat_shape %acc2, %x3 : (!tosa.shape<9>, !tosa.shape<1>) -> !tosa.shape<10>
+  %acc4 = tosa.concat_shape %acc3, %x4 : (!tosa.shape<10>, !tosa.shape<3>) -> !tosa.shape<13>
+
+  return %acc4 : !tosa.shape<13>
+}
+// -----
+
>From 5f020d287280d65908f8127232682137039eb63b Mon Sep 17 00:00:00 2001
From: Udaya Ranga <udaya.ranga at arm.com>
Date: Tue, 27 Jan 2026 15:08:25 +0000
Subject: [PATCH 2/3] CONCAT_SHAPE Op folder
Apply compile-time folding for TOSA.CONCAT_SHAPE
Signed-off-by: Udaya Ranga <udaya.ranga at arm.com>
---
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 4 +---
mlir/test/Dialect/Tosa/constant_folding.mlir | 13 +++++++++++++
2 files changed, 14 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 59668c7690571..7c39ab3474613 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -2077,11 +2077,9 @@ OpFoldResult concatShapeFold(tosa::ConcatShapeOp *op) {
return {};
const auto vAttr = cast<DenseElementsAttr>(vConstShape.getValues());
- if (!vAttr)
- return {};
+ assert(vAttr);
const auto vETy = vAttr.getElementType();
- (void)vETy;
assert(vETy.isIntOrIndex());
auto const vAttrVals = vAttr.getValues<APInt>();
for (auto const &v : vAttrVals) {
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index c8276a56b8ddf..dd35f777bebb2 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -1457,3 +1457,16 @@ func.func @test_concat_shape_total_rank13_5inputs() -> !tosa.shape<13> {
}
// -----
+// CHECK-LABEL: @test_concat_shape_total_rank9_shapes
+// CHECK: %[[ABC:.*]] = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<9xindex>} : () -> !tosa.shape<9>
+// CHECK-NOT: tosa.concat_shape
+// CHECK: return %[[ABC]] : !tosa.shape<9>
+func.func @test_concat_shape_total_rank9_shapes() -> !tosa.shape<9> {
+  // Ranks: 3 + 2 + 4 = 9, combined by one variadic concat_shape.
+  %a = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %b = tosa.const_shape {values = dense<[4, 5]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %c = tosa.const_shape {values = dense<[6, 7, 8, 9]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %abc = tosa.concat_shape %a, %b, %c : (!tosa.shape<3>, !tosa.shape<2>, !tosa.shape<4>) -> !tosa.shape<9>
+  return %abc : !tosa.shape<9>
+}
+// -----
>From 41b95cbc3912566e73fee0adf3c7ab96cdc5fc21 Mon Sep 17 00:00:00 2001
From: Udaya Ranga <udaya.ranga at arm.com>
Date: Tue, 27 Jan 2026 15:08:25 +0000
Subject: [PATCH 3/3] CONCAT_SHAPE Op folder
Apply compile-time folding for TOSA.CONCAT_SHAPE
Signed-off-by: Udaya Ranga <udaya.ranga at arm.com>
---
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 2 --
1 file changed, 2 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 7c39ab3474613..b5e067f76a979 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -2079,8 +2079,6 @@ OpFoldResult concatShapeFold(tosa::ConcatShapeOp *op) {
const auto vAttr = cast<DenseElementsAttr>(vConstShape.getValues());
assert(vAttr);
- const auto vETy = vAttr.getElementType();
- assert(vETy.isIntOrIndex());
auto const vAttrVals = vAttr.getValues<APInt>();
for (auto const &v : vAttrVals) {
concatDims.push_back(v);
More information about the Mlir-commits
mailing list