[Mlir-commits] [mlir] [mlir][tosa] Add constant folding support for `tosa.dim` (PR #176975)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 20 09:47:59 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Luke Hutton (lhutton1)

<details>
<summary>Changes</summary>

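This allows `tosa.dim` to be folded into a constant shape when the requested input dimension is static, which enhances shape inference.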

---
Full diff: https://github.com/llvm/llvm-project/pull/176975.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td (+2) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+15) 
- (modified) mlir/test/Dialect/Tosa/constant_folding.mlir (+27) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index d8597151714c3..9d123dde562e0 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -137,6 +137,8 @@ def Tosa_DimOp : Tosa_ShapeOp<"dim", [Pure]> {
   let results = (outs Tosa_Shape:$output);
 
   let hasVerifier = 1;
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index b15a3a4279064..c5b4299e6ce6d 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1683,3 +1683,18 @@ OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
 
   return {};
 }
+
+OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
+  // Fold only ranked inputs whose `axis` dimension is in range and static.
+  const auto inputTy = llvm::dyn_cast<ShapedType>(getInput1().getType());
+  const int32_t axis = getAxis();
+  if (!inputTy || !inputTy.hasRank() || axis < 0 || axis >= inputTy.getRank())
+    return {};
+  const int64_t dimSize = inputTy.getDimSize(axis);
+  if (ShapedType::isDynamic(dimSize))
+    return {};
+  // Encode the size as a tensor<rank x index> matching the !tosa.shape result.
+  const int64_t rank = cast<tosa::shapeType>(getResult().getType()).getRank();
+  auto attrTy = RankedTensorType::get({rank}, IndexType::get(getContext()));
+  return DenseElementsAttr::get(attrTy, dimSize);
+}
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 8c375b6c528ef..7860127947bfa 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -650,3 +650,30 @@ func.func @no_shift_op_reorder (%arg0 : tensor<44x1xi16>, %arg1 : tensor<1xi8>)
   %1 = tosa.mul %arg0, %0, %arg1 : (tensor<44x1xi16>, tensor<44x57xi16>, tensor<1xi8>) -> tensor<44x57xi32>
   return %1 : tensor<44x57xi32>
 }
+
+// -----
+
+// CHECK-LABEL: @test_fold_dim
+// CHECK: tosa.const_shape  {values = dense<6> : tensor<1xindex>} : () -> !tosa.shape<1>
+func.func @test_fold_dim(%input: tensor<6xi32>) -> !tosa.shape<1> {
+  %size = tosa.dim %input {axis = 0 : i32} : (tensor<6xi32>) -> !tosa.shape<1>
+  return %size : !tosa.shape<1>
+}
+
+// -----
+
+// CHECK-LABEL: @test_no_fold_dim_unranked_input
+// CHECK: tosa.dim
+func.func @test_no_fold_dim_unranked_input(%input: tensor<*xi32>) -> !tosa.shape<1> {
+  %size = tosa.dim %input {axis = 0 : i32} : (tensor<*xi32>) -> !tosa.shape<1>
+  return %size : !tosa.shape<1>
+}
+
+// -----
+
+// CHECK-LABEL: @test_no_fold_dim_dynamic
+// CHECK: tosa.dim
+func.func @test_no_fold_dim_dynamic(%input: tensor<4x?xi32>) -> !tosa.shape<1> {
+  %size = tosa.dim %input {axis = 1 : i32} : (tensor<4x?xi32>) -> !tosa.shape<1>
+  return %size : !tosa.shape<1>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/176975


More information about the Mlir-commits mailing list