[Mlir-commits] [mlir] [mlir][tosa] Add constant folding support for `tosa.dim` (PR #176975)

Luke Hutton llvmlistbot at llvm.org
Tue Jan 27 09:10:54 PST 2026


https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/176975

>From b57da9acc34d100db4be5497ed2047e6419a5e38 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Fri, 19 Dec 2025 10:24:19 +0000
Subject: [PATCH 1/2] [mlir][tosa] Add constant folding support for `tosa.dim`

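Fold `tosa.dim` to a `tosa.const_shape` when the input is ranked and the
queried dimension is static. This enhances shape inference.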

Change-Id: I4cba9456c0acac2ce8aeb8cdeb69052be664bc21
---
 .../mlir/Dialect/Tosa/IR/TosaShapeOps.td      |  2 ++
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 15 +++++++++++
 mlir/test/Dialect/Tosa/constant_folding.mlir  | 27 +++++++++++++++++++
 3 files changed, 44 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index 1783a5ef7c961..57fd1d2d20aa8 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -139,6 +139,8 @@ def Tosa_DimOp : Tosa_ShapeOp<"dim", [Pure]> {
   let results = (outs Tosa_Shape:$output);
 
   let hasVerifier = 1;
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 5f41c8c3f300f..d4d3df3d9a952 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1764,3 +1764,18 @@ OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
   return binaryFolder<FoldAddAdaptor>(
       input1Attr, input2Attr, input1Attr.getType(), /*foldDenseValues=*/true);
 }
+
+OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
+  const auto inputTy = llvm::dyn_cast<ShapedType>(getInput1().getType());
+  const int32_t axis = getAxis();
+  if (!inputTy || !inputTy.hasRank() || axis < 0 || axis >= inputTy.getRank())
+    return {}; // Nothing to fold for an unranked input or invalid axis.
+  const int64_t dimSize = inputTy.getDimSize(axis);
+  if (ShapedType::isDynamic(dimSize))
+    return {};
+
+  OpBuilder builder(getContext());
+  const int64_t rank = cast<tosa::shapeType>(getResult().getType()).getRank();
+  const auto resultAttrTy = RankedTensorType::get(rank, builder.getIndexType());
+  return DenseElementsAttr::get(resultAttrTy, dimSize);
+}
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 1007af6c8bd82..4cb9a46e5d049 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -683,3 +683,30 @@ func.func @test_no_fold_add_shape_negative_overflow() -> !tosa.shape<6> {
   %c = tosa.add_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6>
   return %c : !tosa.shape<6>
 }
+
+// -----
+
+// CHECK-LABEL: @test_fold_dim
+// CHECK: tosa.const_shape  {values = dense<6> : tensor<1xindex>} : () -> !tosa.shape<1>
+func.func @test_fold_dim(%arg0: tensor<6xi32>) -> !tosa.shape<1> {
+  %dim = tosa.dim %arg0 {axis = 0 : i32} : (tensor<6xi32>) -> !tosa.shape<1>
+  return %dim : !tosa.shape<1>
+}
+
+// -----
+
+// CHECK-LABEL: @test_no_fold_dim_unranked_input
+// CHECK: tosa.dim
+func.func @test_no_fold_dim_unranked_input(%arg0: tensor<*xi32>) -> !tosa.shape<1> {
+  %dim = tosa.dim %arg0 {axis = 0 : i32} : (tensor<*xi32>) -> !tosa.shape<1>
+  return %dim : !tosa.shape<1>
+}
+
+// -----
+
+// CHECK-LABEL: @test_no_fold_dim_dynamic
+// CHECK: tosa.dim
+func.func @test_no_fold_dim_dynamic(%arg0: tensor<4x?xi32>) -> !tosa.shape<1> {
+  %dim = tosa.dim %arg0 {axis = 1 : i32} : (tensor<4x?xi32>) -> !tosa.shape<1>
+  return %dim : !tosa.shape<1>
+}

>From d41d7b0ae42acf23f7682f6c6954b42211f64c6d Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 27 Jan 2026 17:07:53 +0000
Subject: [PATCH 2/2] Don't query the result rank; it is always 1.

Change-Id: I553ab9701cf8f5fdfc6f3ec6344fa1545e0044a9
---
 mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index d4d3df3d9a952..07a0b6742d48a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1775,7 +1775,7 @@ OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
     return {};
 
   OpBuilder builder(getContext());
-  const int64_t rank = cast<tosa::shapeType>(getResult().getType()).getRank();
-  const auto resultAttrTy = RankedTensorType::get(rank, builder.getIndexType());
+  const auto resultAttrTy =
+      RankedTensorType::get(/*shape=*/{1}, builder.getIndexType());
   return DenseElementsAttr::get(resultAttrTy, dimSize);
 }



More information about the Mlir-commits mailing list