[Mlir-commits] [mlir] [mlir][tosa] Add concat_shape Op folder (PR #183293)

Udaya Ranga llvmlistbot at llvm.org
Thu Mar 12 10:52:33 PDT 2026


https://github.com/udaya-ranga updated https://github.com/llvm/llvm-project/pull/183293

>From 785835f25bb1018a912bedd20ed4de313ada3f21 Mon Sep 17 00:00:00 2001
From: Udaya Ranga <udaya.ranga at arm.com>
Date: Tue, 27 Jan 2026 15:08:25 +0000
Subject: [PATCH 1/3] CONCAT_SHAPE Op folder

Signed-off-by: Udaya Ranga <udaya.ranga at arm.com>
Change-Id: Icdedada797ec7be9825e31d44a11b80832891ea1
---
 .../mlir/Dialect/Tosa/IR/TosaShapeOps.td      |  2 +
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 38 ++++++++
 mlir/test/Dialect/Tosa/constant_folding.mlir  | 86 +++++++++++++++++++
 3 files changed, 126 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index 3fc9ca5d810fe..f2c48b7684c26 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -89,6 +89,8 @@ def Tosa_ConcatShapeOp : Tosa_ShapeOp<"concat_shape", [Pure]> {
   let results = (outs Tosa_Shape:$output);
 
   let hasVerifier = 1;
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 7a1dbcd3e84c7..b8c7dab7ddd68 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -2020,6 +2020,50 @@ OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
   return DenseElementsAttr::get(resultAttrTy, dimSize);
 }

+/// Folds a tosa.concat_shape whose operands are all produced by
+/// tosa.const_shape ops into one constant shape attribute.
+///
+/// Returns a null OpFoldResult (no fold) when there are no operands or when
+/// any operand is not defined by a tosa.const_shape op, e.g. a block
+/// argument, which has no defining op at all.
+static OpFoldResult concatShapeFold(tosa::ConcatShapeOp *op) {
+  auto const inputs = op->getInput();
+
+  if (inputs.empty())
+    return {};
+
+  SmallVector<APInt> concatDims;
+  concatDims.reserve(/*max elem*/ 64);
+  for (Value input : inputs) {
+    // Use the null-safe getDefiningOp<OpTy>(): a block argument has no
+    // defining op, and dyn_cast on a null Operation* asserts.
+    auto vConstShape = input.getDefiningOp<tosa::ConstShapeOp>();
+    if (!vConstShape)
+      return {};
+
+    const auto vAttr = cast<DenseElementsAttr>(vConstShape.getValues());
+    if (!vAttr)
+      return {};
+
+    const auto vETy = vAttr.getElementType();
+    (void)vETy;
+    assert(vETy.isIntOrIndex());
+    auto const vAttrVals = vAttr.getValues<APInt>();
+    for (auto const &v : vAttrVals) {
+      concatDims.push_back(v);
+    }
+  }
+
+  // The fold result is a rank-1 index tensor holding every operand's
+  // values in operand order.
+  auto *ctx = op->getContext();
+  assert(ctx != nullptr && "ctx is nullptr");
+  auto const rankedTy = RankedTensorType::get(
+      {static_cast<int64_t>(concatDims.size())}, IndexType::get(ctx));
+
+  return DenseElementsAttr::get(rankedTy, concatDims);
+}
+
 OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
   return binaryFold<AddShapeOp, AddFoldAdaptor>(this);
 }
@@ -2063,3 +2097,9 @@ OpFoldResult tosa::Log2CeilShapeOp::fold(FoldAdaptor adaptor) {
 OpFoldResult tosa::Log2FloorShapeOp::fold(FoldAdaptor adaptor) {
   return unaryShapeFold<Log2FloorShapeOp, Log2FloorFoldAdaptor>(this);
 }
+
+/// Folds to a single constant shape when every operand is produced by a
+/// tosa.const_shape op (see concatShapeFold); `adaptor` is not consulted.
+OpFoldResult tosa::ConcatShapeOp::fold(FoldAdaptor adaptor) {
+  return concatShapeFold(this);
+}
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 47b9e36b08558..c8276a56b8ddf 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -1371,3 +1371,89 @@ func.func @test_log2_floor_shape_neg() -> !tosa.shape<6> {
   %c = tosa.log2_floor_shape %a : (!tosa.shape<6>) -> !tosa.shape<6>
   return %c : !tosa.shape<6>
 }
+
+// -----
+
+// CHECK-LABEL: @test_concat_shape
+// CHECK: tosa.const_shape  {values = dense<[4, 9, 5, 19]> : tensor<4xindex>} : () -> !tosa.shape<4>
+func.func @test_concat_shape() -> !tosa.shape<4> {
+  %a = tosa.const_shape {values = dense<[4, 9]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %b = tosa.const_shape {values = dense<[5, 19]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %c = tosa.concat_shape %a, %b : (!tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<4>
+  return %c : !tosa.shape<4>
+}
+
+// -----
+
+// CHECK-LABEL: @test_concat_shape_rank6
+// CHECK: tosa.const_shape  {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]> : tensor<12xindex>} : () -> !tosa.shape<12>
+func.func @test_concat_shape_rank6() -> !tosa.shape<12> {
+  %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %b = tosa.const_shape {values = dense<[7, 8, 9, 10, 11, 12]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.concat_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<12>
+  return %c : !tosa.shape<12>
+}
+
+// -----
+
+// CHECK-LABEL: @test_concat_shape_total_rank12_5inputs
+// CHECK-DAG: %[[C:.*]] = tosa.const_shape {values = dense<[6, 7, 8, 9]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[D:.*]] = tosa.const_shape {values = dense<10> : tensor<1xindex>} : () -> !tosa.shape<1>
+// CHECK-DAG: %[[E:.*]] = tosa.const_shape {values = dense<[11, 12]> : tensor<2xindex>} : () -> !tosa.shape<2>
+// CHECK-DAG: %[[AB:.*]] = tosa.const_shape {values = dense<[1, 2, 3, 4, 5]> : tensor<5xindex>} : () -> !tosa.shape<5>
+// CHECK: %[[ABC:.*]] = tosa.concat_shape %[[AB]], %[[C]] : (!tosa.shape<5>, !tosa.shape<4>) -> !tosa.shape<9>
+// CHECK: %[[ABCD:.*]] = tosa.concat_shape %[[ABC]], %[[D]] : (!tosa.shape<9>, !tosa.shape<1>) -> !tosa.shape<10>
+// CHECK: %[[ABCDE:.*]] = tosa.concat_shape %[[ABCD]], %[[E]] : (!tosa.shape<10>, !tosa.shape<2>) -> !tosa.shape<12>
+// CHECK: return %[[ABCDE]] : !tosa.shape<12>
+func.func @test_concat_shape_total_rank12_5inputs() -> !tosa.shape<12> {
+  // Ranks: 3 + 2 + 4 + 1 + 2 = 12; the CHECKs expect only %ab to fold, the later concats stay.
+  %a = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %b = tosa.const_shape {values = dense<[4, 5]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %c = tosa.const_shape {values = dense<[6, 7, 8, 9]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %d = tosa.const_shape {values = dense<[10]> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %e = tosa.const_shape {values = dense<[11, 12]> : tensor<2xindex>} : () -> !tosa.shape<2>
+
+  %ab = tosa.concat_shape %a, %b : (!tosa.shape<3>, !tosa.shape<2>) -> !tosa.shape<5>
+  %abc = tosa.concat_shape %ab, %c : (!tosa.shape<5>, !tosa.shape<4>) -> !tosa.shape<9>
+  %abcd = tosa.concat_shape %abc, %d : (!tosa.shape<9>, !tosa.shape<1>) -> !tosa.shape<10>
+  %abcde = tosa.concat_shape %abcd, %e : (!tosa.shape<10>, !tosa.shape<2>) -> !tosa.shape<12>
+
+  return %abcde : !tosa.shape<12>
+}
+
+// -----
+
+// CHECK-LABEL: @test_concat_shape_rank6_4inputs
+// CHECK: tosa.concat_shape
+func.func @test_concat_shape_rank6_4inputs() -> !tosa.shape<24> {
+  %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %b = tosa.const_shape {values = dense<[7, 8, 9, 10, 11, 12]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.const_shape {values = dense<[13, 14, 15, 16, 17, 18]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %d = tosa.const_shape {values = dense<[19, 20, 21, 22, 23, 24]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %ab = tosa.concat_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<12>
+  %cd = tosa.concat_shape %c, %d : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<12>
+  %abcd = tosa.concat_shape %ab, %cd : (!tosa.shape<12>, !tosa.shape<12>) -> !tosa.shape<24>
+  return %abcd : !tosa.shape<24>
+}
+
+// -----
+
+// CHECK-LABEL: @test_concat_shape_total_rank13_5inputs
+// CHECK: tosa.concat_shape
+func.func @test_concat_shape_total_rank13_5inputs() -> !tosa.shape<13> {
+  // Ranks: 3 + 2 + 4 + 1 + 3 = 13; the CHECK only requires that some concat_shape remains.
+  %a = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %b = tosa.const_shape {values = dense<[4, 5]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %c = tosa.const_shape {values = dense<[6, 7, 8, 9]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %d = tosa.const_shape {values = dense<[10]> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %e = tosa.const_shape {values = dense<[11, 12, 13]> : tensor<3xindex>} : () -> !tosa.shape<3>
+
+  %ab = tosa.concat_shape %a, %b : (!tosa.shape<3>, !tosa.shape<2>) -> !tosa.shape<5>
+  %abc = tosa.concat_shape %ab, %c : (!tosa.shape<5>, !tosa.shape<4>) -> !tosa.shape<9>
+  %abcd = tosa.concat_shape %abc, %d : (!tosa.shape<9>, !tosa.shape<1>) -> !tosa.shape<10>
+  %abcde = tosa.concat_shape %abcd, %e : (!tosa.shape<10>, !tosa.shape<3>) -> !tosa.shape<13>
+
+  return %abcde : !tosa.shape<13>
+}
+// -----
+

>From 5f020d287280d65908f8127232682137039eb63b Mon Sep 17 00:00:00 2001
From: Udaya Ranga <udaya.ranga at arm.com>
Date: Tue, 27 Jan 2026 15:08:25 +0000
Subject: [PATCH 2/3] CONCAT_SHAPE Op folder

Apply compile-time folding for TOSA.CONCAT_SHAPE

Signed-off-by: Udaya Ranga <udaya.ranga at arm.com>
---
 mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp |  4 +---
 mlir/test/Dialect/Tosa/constant_folding.mlir       | 13 +++++++++++++
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 59668c7690571..7c39ab3474613 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -2077,11 +2077,9 @@ OpFoldResult concatShapeFold(tosa::ConcatShapeOp *op) {
       return {};
 
     const auto vAttr = cast<DenseElementsAttr>(vConstShape.getValues());
-    if (!vAttr)
-      return {};
+    assert(vAttr);
 
     const auto vETy = vAttr.getElementType();
-    (void)vETy;
     assert(vETy.isIntOrIndex());
     auto const vAttrVals = vAttr.getValues<APInt>();
     for (auto const &v : vAttrVals) {
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index c8276a56b8ddf..dd35f777bebb2 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -1457,3 +1457,17 @@ func.func @test_concat_shape_total_rank13_5inputs() -> !tosa.shape<13> {
 }
 // -----

+// CHECK-LABEL: @test_concat_shape_total_rank9_shapes
+// CHECK: tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<9xindex>} : () -> !tosa.shape<9>
+// CHECK-NOT: tosa.concat_shape
+func.func @test_concat_shape_total_rank9_shapes() -> !tosa.shape<9> {
+  // Ranks: 3 + 2 + 4 = 9
+  %a = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %b = tosa.const_shape {values = dense<[4, 5]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %c = tosa.const_shape {values = dense<[6, 7, 8, 9]> : tensor<4xindex>} : () -> !tosa.shape<4>
+
+  %abc = tosa.concat_shape %a, %b, %c : (!tosa.shape<3>, !tosa.shape<2>, !tosa.shape<4>) -> !tosa.shape<9>
+
+  return %abc : !tosa.shape<9>
+}
+// -----

>From 41b95cbc3912566e73fee0adf3c7ab96cdc5fc21 Mon Sep 17 00:00:00 2001
From: Udaya Ranga <udaya.ranga at arm.com>
Date: Tue, 27 Jan 2026 15:08:25 +0000
Subject: [PATCH 3/3] CONCAT_SHAPE Op folder

Apply compile-time folding for TOSA.CONCAT_SHAPE

Signed-off-by: Udaya Ranga <udaya.ranga at arm.com>
---
 mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 7c39ab3474613..b5e067f76a979 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -2079,8 +2079,6 @@ OpFoldResult concatShapeFold(tosa::ConcatShapeOp *op) {
     const auto vAttr = cast<DenseElementsAttr>(vConstShape.getValues());
     assert(vAttr);
 
-    const auto vETy = vAttr.getElementType();
-    assert(vETy.isIntOrIndex());
     auto const vAttrVals = vAttr.getValues<APInt>();
     for (auto const &v : vAttrVals) {
       concatDims.push_back(v);



More information about the Mlir-commits mailing list