[Mlir-commits] [mlir] [mlir][tosa] Harden folds/canonicalizations for unranked and dynamic shapes (PR #188188)
Hocky Yudhiono
llvmlistbot at llvm.org
Thu Mar 26 08:33:21 PDT 2026
https://github.com/hockyy updated https://github.com/llvm/llvm-project/pull/188188
>From 283218663a6e96be9bbc40c228b537761624d80a Mon Sep 17 00:00:00 2001
From: Hocky Yudhiono <hocky.yudhiono at gmail.com>
Date: Wed, 25 Mar 2026 11:51:29 +0800
Subject: [PATCH 1/4] [mlir][tosa] Harden folds/canonicalizations for unranked
and dynamic shapes
---
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 65 +++++++-----
.../Dialect/Tosa/Transforms/TosaFolders.cpp | 8 ++
mlir/test/Dialect/Tosa/canonicalize.mlir | 98 ++++++++++++++++++-
mlir/test/Dialect/Tosa/constant_folding.mlir | 36 ++++++-
.../Tosa/tosa-layerwise-constant-fold.mlir | 19 ++++
5 files changed, 198 insertions(+), 28 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index b622cbedec1dc..b95e6819d6268 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -31,6 +31,13 @@
using namespace mlir;
using namespace mlir::tosa;
+/// Returns `input` as the fold result of `op` only when its type is exactly
+/// the op's result type; otherwise returns an empty OpFoldResult (no fold).
+template <typename OpTy>
+static OpFoldResult foldToInputIfTypeMatches(OpTy op, Value input) {
+  return input.getType() == op.getType() ? OpFoldResult(input) : OpFoldResult{};
+}
+
//===----------------------------------------------------------------------===//
// Operator Canonicalizers.
//===----------------------------------------------------------------------===//
@@ -423,7 +430,7 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
LogicalResult matchAndRewrite(tosa::ClampOp op,
PatternRewriter &rewriter) const override {
Value input = op.getInput();
- auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
+ auto inputType = llvm::dyn_cast<ShapedType>(op.getInput().getType());
auto inputElementType = inputType.getElementType();
if (isa<FloatType>(inputElementType)) {
@@ -843,6 +850,8 @@ struct SliceDynamicSizeCanonicalization
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
PatternRewriter &rewriter) const override {
ShapedType resultType = cast<ShapedType>(sliceOp.getType());
+ if (!resultType.hasRank())
+ return rewriter.notifyMatchFailure(sliceOp, "output must be ranked");
ElementsAttr sizeElems;
if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) {
@@ -946,6 +955,11 @@ binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy,
if (!lhs || !rhs)
return {};
+ // DenseElementsAttr::get needs a static shape. Result types may be unranked
+ // (no RankedTensorType) or ranked-dynamic while operands are dense splats.
+ if (!returnTy || !returnTy.hasStaticShape())
+ return {};
+
const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
if (lETy != rETy)
@@ -994,6 +1008,9 @@ static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy,
if (!val)
return {};
+ if (!returnTy || !returnTy.hasStaticShape())
+ return {};
+
const auto vETy = llvm::cast<ShapedType>(val.getType()).getElementType();
if (val.isSplat()) {
@@ -1496,7 +1513,7 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
}
OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
- auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
+ auto resultTy = llvm::dyn_cast<ShapedType>(getType());
auto lhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
auto rhsAttr =
@@ -1509,7 +1526,7 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
}
OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
- auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
+ auto resultTy = llvm::dyn_cast<ShapedType>(getType());
auto lhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
auto rhsAttr =
@@ -1522,7 +1539,7 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
}
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
- auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
+ auto resultTy = llvm::dyn_cast<ShapedType>(getType());
auto lhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
auto rhsAttr =
@@ -1554,6 +1571,8 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
auto inTy = llvm::cast<ShapedType>(getInput().getType());
auto outTy = llvm::cast<ShapedType>(getType());
+ if (!outTy.hasStaticShape())
+ return {};
auto inETy = inTy.getElementType();
auto outETy = outTy.getElementType();
@@ -1735,29 +1754,22 @@ OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
return {};
}
- auto input = getInput();
- auto inputTy = llvm::cast<RankedTensorType>(input.getType());
- auto resultTy = llvm::cast<RankedTensorType>(getType());
- if (inputTy != resultTy)
- return {};
-
- return input;
+ return foldToInputIfTypeMatches(*this, getInput());
}
OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
auto operand = getInput1();
auto operandTy = llvm::cast<ShapedType>(operand.getType());
auto axis = getAxis();
- auto operandAttr =
- llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
- if (operandAttr)
- return operandAttr;
-
- // If the dim-length is 1, tosa.reverse is a no-op.
- if (operandTy.hasRank() &&
- (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
- return operand;
-
+ // Check if the reverse is a no-op
+ // If the operand is a splat, the reverse is a no-op.
+ bool noOpReverse =
+ llvm::isa_and_nonnull<SplatElementsAttr>(adaptor.getInput1());
+ // If the dim-length is 1, or reversing axis is unit-dim, also a no-op.
+ noOpReverse |= (operandTy.hasRank() && (operandTy.getRank() == 0 ||
+ operandTy.getDimSize(axis) == 1));
+ if (noOpReverse)
+ return foldToInputIfTypeMatches(*this, operand);
return {};
}
@@ -1920,7 +1932,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
return {};
- return getInput1();
+ return foldToInputIfTypeMatches(*this, getInput1());
}
OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
@@ -1953,15 +1965,14 @@ OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
return {};
}
- return definingOp.getInput1();
+ return foldToInputIfTypeMatches(*this, definingOp.getInput1());
}
OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
auto input = getInput1();
// Element-wise abs(abs(x)) = abs(x)
- if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
- return input;
- }
+ if (input.getDefiningOp<tosa::AbsOp>())
+ return foldToInputIfTypeMatches(*this, input);
return {};
}
@@ -2009,6 +2020,8 @@ OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
return {};
auto shapeType = llvm::cast<ShapedType>(getType());
+ if (!shapeType.hasStaticShape())
+ return {};
if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
auto floatVal = inputAttr.getSplatValue<APFloat>();
return DenseElementsAttr::get(shapeType,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index 0a035bbd3df00..7f97f8ba0becd 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -242,6 +242,8 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const override {
auto outputType = cast<ShapedType>(op.getType());
+ if (!outputType.hasStaticShape())
+ return failure();
// TOSA supports quantized types.
if (!outputType.getElementType().isIntOrIndexOrFloat())
return failure();
@@ -269,6 +271,8 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
}
};
+/// Fold `tosa.reciprocal` into `tosa.const` when the operand is a dense float
+/// TOSA constant, types match, and `constantUnaryOpShouldBeFolded` allows it.
struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
using OpRewritePattern::OpRewritePattern;
@@ -295,6 +299,10 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
"tensor has a single user");
}
+ if (inputTensor.getType() != recip.getType())
+ return rewriter.notifyMatchFailure(
+ recip, "input tensor and reciprocal output have different type");
+
// Create a new tensor with the updated values
auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
inputValues, &ReciprocalOp::calcOneElement,
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 52098413f18d9..1f527e3f566ce 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -900,12 +900,77 @@ func.func @fold_resize_bilinear(%arg0 : tensor<1x15x13x1xi8>) -> tensor<1x15x13x
%scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
- %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<1x15x13x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x15x13x1xi8>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = BILINEAR} : (tensor<1x15x13x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x15x13x1xi8>
return %resize : tensor<1x15x13x1xi8>
}
// -----
+// ResizeOp::fold: unit scale (1:1 Y and X), zero offset/border, in/out types equal.
+// CHECK-LABEL: @fold_resize_identity_scale
+func.func @fold_resize_identity_scale(%arg0 : tensor<1x15x13x1xf32>) -> tensor<1x15x13x1xf32> {
+ // CHECK-NOT: tosa.resize
+ %scale = tosa.const_shape { values = dense<[1, 1, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
+ %offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<1x15x13x1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x15x13x1xf32>
+ return %resize : tensor<1x15x13x1xf32>
+}
+
+// -----
+// CHECK-LABEL: @fold_resize_identity_scale_to_unranked
+func.func @fold_resize_identity_scale_to_unranked(%arg0 : tensor<1x15x13x1xf32>) -> tensor<*xf32> {
+ // CHECK: tosa.resize
+ %scale = tosa.const_shape { values = dense<[1, 1, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
+ %offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<1x15x13x1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+ return %resize : tensor<*xf32>
+}
+
+// -----
+
+// Same parameters except scale_y_n != scale_y_d: fold must not apply.
+// CHECK-LABEL: @resize_nofold_asymmetric_y_scale
+func.func @resize_nofold_asymmetric_y_scale(%arg0 : tensor<1x15x13x1xf32>) -> tensor<1x29x13x1xf32> {
+ // CHECK: tosa.resize
+ %scale = tosa.const_shape { values = dense<[4, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
+ %offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<1x15x13x1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x29x13x1xf32>
+ return %resize : tensor<1x29x13x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @dont_canonicalize_unranked_clamp
+func.func @dont_canonicalize_unranked_clamp(%arg0 : tensor<*xf32>) -> tensor<*xf32> {
+ // CHECK: tosa.clamp
+ %0 = tosa.clamp %arg0 {min_val = 0.0 : f32, max_val = 1.0 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @dont_canonicalize_unranked_input_clamp
+func.func @dont_canonicalize_unranked_input_clamp(%arg0 : tensor<*xf32>) -> tensor<1xf32> {
+  // CHECK: tosa.clamp
+  %0 = tosa.clamp %arg0 {min_val = 0.0 : f32, max_val = 1.0 : f32} : (tensor<*xf32>) -> tensor<1xf32>
+  return %0 : tensor<1xf32>
+}
+// -----
+
+// CHECK-LABEL: @dont_canonicalize_unranked_slice_dynamic_size
+func.func @dont_canonicalize_unranked_slice_dynamic_size(%arg0: tensor<1x4xf32>) -> tensor<*xf32> {
+ // CHECK: tosa.slice
+ %start = tosa.const_shape {values = dense<[0, 0]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %size = tosa.const_shape {values = dense<[1, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %0 = tosa.slice %arg0, %start, %size : (tensor<1x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+
// CHECK-LABEL: @canonicalize_concat_slice_final_axis
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12x1xf32>, %[[VAL_1:.*]]: tensor<1x12x12x1xf32>
// CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
@@ -1202,6 +1267,37 @@ func.func @reverse_quant_fold() -> tensor<1x!quant.uniform<i8:f32, 3.07574046018
// -----
+// ReverseOp::fold: an unranked operand (hasRank() == false) must not fold.
+// CHECK-LABEL: @reverse_nofold_unranked_operand
+func.func @reverse_nofold_unranked_operand(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+ // CHECK: tosa.reverse
+ %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<*xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+
+// Reversing a unit dim is a no-op, but the result type differs from the input type: no fold.
+// CHECK-LABEL: @reverse_nofold_unit_dim_unranked_result
+func.func @reverse_nofold_unit_dim_unranked_result(%arg0: tensor<1x4xf32>) -> tensor<*xf32> {
+ // CHECK: tosa.reverse
+ %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<1x4xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+
+// Reversing a splat is a no-op, but the result type differs from the input type: no fold.
+// CHECK-LABEL: @reverse_nofold_splat_type_unmatch
+func.func @reverse_nofold_splat_type_unmatch() -> tensor<*xf32> {
+ // CHECK: tosa.reverse
+ %0 = "tosa.const"() <{values = dense<1.0> : tensor<4xf32>}> : () -> tensor<4xf32>
+ %1 = tosa.reverse %0 {axis = 0 : i32} : (tensor<4xf32>) -> tensor<*xf32>
+ return %1 : tensor<*xf32>
+}
+
+// -----
+
// CHECK-LABEL: @select_quant_fold
func.func @select_quant_fold() -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>> {
// CHECK: %[[CONST_0:.*]] = "tosa.const"() <{values = dense<0> : tensor<i8>}> : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index dc040d3231964..7a26b5475ffda 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -28,6 +28,41 @@ func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tens
// -----
+// CHECK-LABEL: func @try_fold_unranked_constant_results
+func.func @try_fold_unranked_constant_results() {
+ // CHECK: tosa.equal
+ // CHECK: tosa.greater
+ // CHECK: tosa.greater_equal
+ // CHECK: tosa.cast
+ // CHECK: tosa.reciprocal
+ // CHECK-NEXT: return
+ %lhs = arith.constant dense<1> : tensor<1xi32>
+ %rhs = arith.constant dense<2> : tensor<1xi32>
+ %f = arith.constant dense<2.0> : tensor<1xf32>
+ %0 = tosa.equal %lhs, %rhs : (tensor<1xi32>, tensor<1xi32>) -> tensor<*xi1>
+ %1 = tosa.greater %lhs, %rhs : (tensor<1xi32>, tensor<1xi32>) -> tensor<*xi1>
+ %2 = tosa.greater_equal %lhs, %rhs : (tensor<1xi32>, tensor<1xi32>) -> tensor<*xi1>
+ %3 = tosa.cast %lhs : (tensor<1xi32>) -> tensor<*xf32>
+ %4 = tosa.reciprocal %f : (tensor<1xf32>) -> tensor<*xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @try_fold_unranked_identity_results
+func.func @try_fold_unranked_identity_results(%arg0: tensor<1xf32>) {
+ // CHECK: tosa.transpose
+ // CHECK: tosa.reverse
+ // CHECK: tosa.abs
+ // CHECK-NEXT: return
+ %0 = tosa.transpose %arg0 { perms = array<i32: 0> } : (tensor<1xf32>) -> tensor<*xf32>
+ %1 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<1xf32>) -> tensor<*xf32>
+ %3 = tosa.abs %arg0 : (tensor<1xf32>) -> tensor<*xf32>
+ return
+}
+
+// -----
+
// CHECK-LABEL: @fold_add_zero_rhs_f32
func.func @fold_add_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
%zero = "tosa.const"() {values = dense<0.0> : tensor<f32>} : () -> tensor<f32>
@@ -1489,4 +1524,3 @@ func.func @test_concat_shape_total_rank9_shapes() -> !tosa.shape<9> {
return %abc : !tosa.shape<9>
}
-// -----
diff --git a/mlir/test/Dialect/Tosa/tosa-layerwise-constant-fold.mlir b/mlir/test/Dialect/Tosa/tosa-layerwise-constant-fold.mlir
index d95d267e8c907..711dfe4d2405e 100644
--- a/mlir/test/Dialect/Tosa/tosa-layerwise-constant-fold.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-layerwise-constant-fold.mlir
@@ -46,6 +46,25 @@ func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
// -----
+// CHECK-LABEL: @transpose_nofold_unranked_result_not_reshape
+func.func @transpose_nofold_unranked_result_not_reshape(%arg0: tensor<6x7xf32>) -> tensor<*xf32> {
+ // CHECK: tosa.transpose
+ %1 = tosa.transpose %arg0 { perms = array<i32: 1, 0> }: (tensor<6x7xf32>) -> tensor<*xf32>
+ return %1 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @reciprocal_nofold_unranked_result
+func.func @reciprocal_nofold_unranked_result() -> tensor<*xf32> {
+ %input = "tosa.const"() {values = dense<2.0> : tensor<6x7xf32>} : () -> tensor<6x7xf32>
+ // CHECK: tosa.reciprocal
+ %1 = tosa.reciprocal %input : (tensor<6x7xf32>) -> tensor<*xf32>
+ return %1 : tensor<*xf32>
+}
+
+// -----
+
// CHECK-LABEL: @transpose_fold_splat
func.func @transpose_fold_splat() -> tensor<3x2xf32> {
%input = "tosa.const"() {values = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
>From 44184f56a452d508a668a73714334eb7395ea65a Mon Sep 17 00:00:00 2001
From: Hocky Yudhiono <hocky.yudhiono at gmail.com>
Date: Thu, 26 Mar 2026 19:17:48 +0800
Subject: [PATCH 2/4] [MLIR][tosa] Address comments
---
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 20 +++++++++----------
.../Dialect/Tosa/Transforms/TosaFolders.cpp | 4 +---
mlir/test/Dialect/Tosa/constant_folding.mlir | 5 +++++
3 files changed, 15 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index b95e6819d6268..bee90bf0cc059 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -955,9 +955,7 @@ binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy,
if (!lhs || !rhs)
return {};
- // DenseElementsAttr::get needs a static shape. Result types may be unranked
- // (no RankedTensorType) or ranked-dynamic while operands are dense splats.
- if (!returnTy || !returnTy.hasStaticShape())
+ if (!returnTy.hasRank() || !returnTy.hasStaticShape())
return {};
const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
@@ -1008,7 +1006,7 @@ static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy,
if (!val)
return {};
- if (!returnTy || !returnTy.hasStaticShape())
+ if (!returnTy.hasRank() || !returnTy.hasStaticShape())
return {};
const auto vETy = llvm::cast<ShapedType>(val.getType()).getElementType();
@@ -1513,7 +1511,7 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
}
OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
- auto resultTy = llvm::dyn_cast<ShapedType>(getType());
+ auto resultTy = llvm::cast<ShapedType>(getType());
auto lhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
auto rhsAttr =
@@ -1526,7 +1524,7 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
}
OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
- auto resultTy = llvm::dyn_cast<ShapedType>(getType());
+ auto resultTy = llvm::cast<ShapedType>(getType());
auto lhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
auto rhsAttr =
@@ -1539,7 +1537,7 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
}
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
- auto resultTy = llvm::dyn_cast<ShapedType>(getType());
+ auto resultTy = llvm::cast<ShapedType>(getType());
auto lhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
auto rhsAttr =
@@ -1550,7 +1548,7 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
// If we are comparing an integer value to itself it is always true. We
// can not do this with float due to float values.
- if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
+ if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy.hasRank() &&
resultTy.hasStaticShape() && lhs == rhs) {
return DenseElementsAttr::get(resultTy, true);
}
@@ -1571,7 +1569,7 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
auto inTy = llvm::cast<ShapedType>(getInput().getType());
auto outTy = llvm::cast<ShapedType>(getType());
- if (!outTy.hasStaticShape())
+ if (!outTy.hasRank() || !outTy.hasStaticShape())
return {};
auto inETy = inTy.getElementType();
auto outETy = outTy.getElementType();
@@ -1921,7 +1919,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
// Transposing splat values just means reshaping.
if (auto input =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
- if (input.isSplat() && resultTy.hasStaticShape() &&
+ if (input.isSplat() && resultTy.hasRank() && resultTy.hasStaticShape() &&
input.getType().getElementType() == resultTy.getElementType())
return input.reshape(resultTy);
}
@@ -2020,7 +2018,7 @@ OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
return {};
auto shapeType = llvm::cast<ShapedType>(getType());
- if (!shapeType.hasStaticShape())
+ if (!shapeType.hasRank() || !shapeType.hasStaticShape())
return {};
if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
auto floatVal = inputAttr.getSplatValue<APFloat>();
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index 7f97f8ba0becd..d6961628afc9f 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -242,7 +242,7 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const override {
auto outputType = cast<ShapedType>(op.getType());
- if (!outputType.hasStaticShape())
+ if (!outputType.hasRank() || !outputType.hasStaticShape())
return failure();
// TOSA supports quantized types.
if (!outputType.getElementType().isIntOrIndexOrFloat())
@@ -271,8 +271,6 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
}
};
-/// Fold `tosa.reciprocal` into `tosa.const` when the operand is a dense float
-/// TOSA constant, types match, and `constantUnaryOpShouldBeFolded` allows it.
struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
using OpRewritePattern::OpRewritePattern;
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 7a26b5475ffda..fa74d07d86165 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -35,6 +35,7 @@ func.func @try_fold_unranked_constant_results() {
// CHECK: tosa.greater_equal
// CHECK: tosa.cast
// CHECK: tosa.reciprocal
+ // CHECK: tosa.abs
// CHECK-NEXT: return
%lhs = arith.constant dense<1> : tensor<1xi32>
%rhs = arith.constant dense<2> : tensor<1xi32>
@@ -44,6 +45,7 @@ func.func @try_fold_unranked_constant_results() {
%2 = tosa.greater_equal %lhs, %rhs : (tensor<1xi32>, tensor<1xi32>) -> tensor<*xi1>
%3 = tosa.cast %lhs : (tensor<1xi32>) -> tensor<*xf32>
%4 = tosa.reciprocal %f : (tensor<1xf32>) -> tensor<*xf32>
+ %5 = tosa.abs %f : (tensor<1xf32>) -> tensor<*xf32>
return
}
@@ -54,10 +56,13 @@ func.func @try_fold_unranked_identity_results(%arg0: tensor<1xf32>) {
// CHECK: tosa.transpose
// CHECK: tosa.reverse
// CHECK: tosa.abs
+ // CHECK: tosa.negate
// CHECK-NEXT: return
+ %zp = arith.constant dense<0.0> : tensor<1xf32>
%0 = tosa.transpose %arg0 { perms = array<i32: 0> } : (tensor<1xf32>) -> tensor<*xf32>
%1 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<1xf32>) -> tensor<*xf32>
%3 = tosa.abs %arg0 : (tensor<1xf32>) -> tensor<*xf32>
+ %4 = tosa.negate %arg0, %zp, %zp : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<*xf32>
return
}
>From 00863be2dafd9db5c5db9bca9eb6eb6fbb27748e Mon Sep 17 00:00:00 2001
From: Hocky Yudhiono <hocky.yudhiono at gmail.com>
Date: Thu, 26 Mar 2026 19:34:05 +0800
Subject: [PATCH 3/4] [mlir][tosa] Fix formatting
---
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index bee90bf0cc059..2ac83fca46028 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1759,15 +1759,16 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
auto operand = getInput1();
auto operandTy = llvm::cast<ShapedType>(operand.getType());
auto axis = getAxis();
- // Check if the reverse is a no-op
- // If the operand is a splat, the reverse is a no-op.
bool noOpReverse =
llvm::isa_and_nonnull<SplatElementsAttr>(adaptor.getInput1());
+
// If the dim-length is 1, or reversing axis is unit-dim, also a no-op.
- noOpReverse |= (operandTy.hasRank() && (operandTy.getRank() == 0 ||
- operandTy.getDimSize(axis) == 1));
+ noOpReverse |= operandTy.hasRank() &&
+ (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1);
+
if (noOpReverse)
return foldToInputIfTypeMatches(*this, operand);
+
return {};
}
>From 7f3f89bb9eb9d37d01652c3e5ac5785a8925b596 Mon Sep 17 00:00:00 2001
From: Hocky Yudhiono <hocky.yudhiono at gmail.com>
Date: Thu, 26 Mar 2026 23:32:56 +0800
Subject: [PATCH 4/4] [mlir][tosa] Fix canonicalization logic
---
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 24 +++++++++++++-----------
 1 file changed, 13 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 2ac83fca46028..2b27573a60247 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1759,17 +1759,19 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
   auto operand = getInput1();
   auto operandTy = llvm::cast<ShapedType>(operand.getType());
   auto axis = getAxis();
-  bool noOpReverse =
-      llvm::isa_and_nonnull<SplatElementsAttr>(adaptor.getInput1());
-
-  // If the dim-length is 1, or reversing axis is unit-dim, also a no-op.
-  noOpReverse |= operandTy.hasRank() &&
-                 (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1);
-
-  if (noOpReverse)
-    return foldToInputIfTypeMatches(*this, operand);
-
-  return {};
+  // Never fold when the operand rank is unknown.
+  if (!operandTy.hasRank())
+    return {};
+
+  // Reversing is a no-op for a splat input, for a rank-0 input (tested
+  // before getDimSize, since no axis index is in range for a rank-0
+  // shape), or when the reversed axis has length 1.
+  const bool isNoOp =
+      llvm::isa_and_nonnull<SplatElementsAttr>(adaptor.getInput1()) ||
+      operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1;
+  if (!isNoOp)
+    return {};
+  return foldToInputIfTypeMatches(*this, operand);
 }
 
 OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
More information about the Mlir-commits
mailing list