[Mlir-commits] [mlir] [mlir][tosa] Check for overflow in integer folders (PR #172695)

Luke Hutton llvmlistbot at llvm.org
Wed Dec 17 09:00:28 PST 2025


https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/172695

For these folders to be TOSA compliant, they need to check for overflow. This commit adds those checks, subsequently preventing folding if an overflow is detected.

This commit also fixes the greater/greater_equal folders to account for unsigned types.

Note: this change is currently dependent on https://github.com/llvm/llvm-project/pull/172691

>From 47596a59518feb29a3f0b5271e9f75df9bb07e86 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 16 Dec 2025 21:38:13 +0000
Subject: [PATCH 1/2] [mlir][tosa] Separate layerwise folding and simple folder
 tests (NFC)

This commit moves the 'simple' folder tests (invoked via
`--canonicalize`) away from other layerwise constant folding
tests (invoked via `--tosa-layerwise-constant-fold`) into a
separate test file to help reduce confusion.

Also rename the layerwise folding test file to reflect the
name of the pass that it is invoked by.

Change-Id: I22bfa76480eddd8f850702986d79608d956e766a
---
 mlir/test/Dialect/Tosa/constant_folding.mlir  | 510 ++++++++++++++++-
 ...mlir => tosa-layerwise-constant-fold.mlir} | 530 +-----------------
 2 files changed, 533 insertions(+), 507 deletions(-)
 rename mlir/test/Dialect/Tosa/{constant-op-fold.mlir => tosa-layerwise-constant-fold.mlir} (63%)

diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index d477a2479e913..bf6e1ad23bcb9 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --test-single-fold %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --test-single-fold %s | FileCheck %s
 
 // CHECK-LABEL: func @test_const
 func.func @test_const(%arg0 : index) -> tensor<4xi32> {
@@ -7,6 +7,8 @@ func.func @test_const(%arg0 : index) -> tensor<4xi32> {
   return %0 : tensor<4xi32>
 }
 
+// -----
+
 // CHECK-LABEL: func @test_const_i64
 func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
   // CHECK: tosa.const
@@ -14,6 +16,8 @@ func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
   return %0 : tensor<4xi64>
 }
 
+// -----
+
 // CHECK-LABEL: func @try_fold_equal_with_unranked_tensor
 func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) {
   // CHECK: tosa.equal
@@ -21,3 +25,507 @@ func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tens
   %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @fold_add_zero_rhs_f32
+func.func @fold_add_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+  %zero = "tosa.const"() {values = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  %add = tosa.add %arg0, %zero : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: return %arg0
+  return %add : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_zero_lhs_f32
+func.func @fold_add_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+  %zero = "tosa.const"() {values = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  %add = tosa.add %zero, %arg0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: return %arg0
+  return %add : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_zero_rhs_i32
+func.func @fold_add_zero_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %zero = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %add = tosa.add %arg0, %zero : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  // CHECK: return %arg0
+  return %add : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_zero_lhs_i32
+func.func @fold_add_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %zero = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %add = tosa.add %zero, %arg0 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  // CHECK: return %arg0
+  return %add : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_splat_i32
+func.func @fold_add_splat_i32() -> tensor<10xi32> {
+  %one = "tosa.const"() {values = dense<1> : tensor<10xi32>} : () -> tensor<10xi32>
+  %two = "tosa.const"() {values = dense<2> : tensor<10xi32>} : () -> tensor<10xi32>
+  %add = tosa.add %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
+  // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<3> : tensor<10xi32>}
+  // CHECK: return %[[THREE]]
+  return %add : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_splat_f32
+func.func @fold_add_splat_f32() -> tensor<10xf32> {
+  %one = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %two = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %add = tosa.add %one, %two : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+  // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<3.000000e+00>
+  // CHECK: return %[[THREE]]
+  return %add : tensor<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_div_zero_lhs_i32
+func.func @fold_div_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %zero = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0>
+  %div = tosa.intdiv %zero, %arg0 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  // CHECK: return %[[ZERO]]
+  return %div : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_div_one_rhs_i32
+func.func @fold_div_one_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %one = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+  %div = tosa.intdiv %arg0, %one : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  // CHECK: return %arg0
+  return %div : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_div_splat_i32
+func.func @fold_div_splat_i32() -> tensor<i32> {
+  %lhs = "tosa.const"() {values = dense<10> : tensor<i32>} : () -> tensor<i32>
+  %rhs = "tosa.const"() {values = dense<-3> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-3>
+  %div = tosa.intdiv %lhs, %rhs : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  // CHECK: return %[[SPLAT]]
+  return %div : tensor<i32>
+}
+
+// -----
+
+
+// CHECK-LABEL: @fold_mul_zero_rhs_f32
+func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+  %zero = "tosa.const"() {values = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0.000000e+00>
+  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %mul = tosa.mul %arg0, %zero, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
+  // CHECK: return %[[ZERO]]
+  return %mul : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_zero_lhs_f32
+func.func @fold_mul_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+  %zero = "tosa.const"() {values = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0.000000e+00>
+  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %mul = tosa.mul %zero, %arg0, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
+  // CHECK: return %[[ZERO]]
+  return %mul : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_zero_rhs_i32
+func.func @fold_mul_zero_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %zero = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0>
+  %mul = tosa.mul %arg0, %zero, %shift : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+  // CHECK: return %[[ZERO]]
+  return %mul : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_zero_lhs_i32
+func.func @fold_mul_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %zero = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0>
+  %mul = tosa.mul %zero, %arg0, %shift : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+  // CHECK: return %[[ZERO]]
+  return %mul : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_one_rhs_f32
+func.func @fold_mul_one_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+  %one = "tosa.const"() {values = dense<1.0> : tensor<f32>} : () -> tensor<f32>
+  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %mul = tosa.mul %arg0, %one, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
+  // CHECK: return %arg0
+  return %mul : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_one_lhs_f32
+func.func @fold_mul_one_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+  %one = "tosa.const"() {values = dense<1.0> : tensor<f32>} : () -> tensor<f32>
+  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %mul = tosa.mul %one, %arg0, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
+  // CHECK: return %arg0
+  return %mul : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_one_rhs_i32
+func.func @fold_mul_one_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %one = "tosa.const"() {values = dense<64> : tensor<i32>} : () -> tensor<i32>
+  %shift = "tosa.const"() {values = dense<6> : tensor<1xi8>} : () -> tensor<1xi8>
+  %mul = tosa.mul %arg0, %one, %shift : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+  // CHECK: return %arg0
+  return %mul : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_one_lhs_i32
+func.func @fold_mul_one_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %one = "tosa.const"() {values = dense<64> : tensor<i32>} : () -> tensor<i32>
+  %shift = "tosa.const"() {values = dense<6> : tensor<1xi8>} : () -> tensor<1xi8>
+  %mul = tosa.mul %one, %arg0, %shift : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+  // CHECK: return %arg0
+  return %mul : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_splat_i8
+func.func @fold_mul_splat_i8() -> tensor<10xi32> {
+  %one = "tosa.const"() {values = dense<17> : tensor<10xi8>} : () -> tensor<10xi8>
+  %two = "tosa.const"() {values = dense<32> : tensor<10xi8>} : () -> tensor<10xi8>
+  %shift = "tosa.const"() {values = dense<3> : tensor<1xi8>} : () -> tensor<1xi8>
+  %mul = tosa.mul %one, %two, %shift : (tensor<10xi8>, tensor<10xi8>, tensor<1xi8>) -> tensor<10xi32>
+  // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<68> : tensor<10xi32>}
+  // CHECK: return %[[THREE]]
+  return %mul : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_splat_f32
+func.func @fold_mul_splat_f32() -> tensor<10xf32> {
+  %one = "tosa.const"() {values = dense<3.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %two = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %mul = tosa.mul %one, %two, %shift : (tensor<10xf32>, tensor<10xf32>, tensor<1xi8>) -> tensor<10xf32>
+  // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<6.000000e+00> : tensor<10xf32>}
+  // CHECK: return %[[THREE]]
+  return %mul : tensor<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_sub_zero_rhs_f32
+func.func @fold_sub_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+  %zero = "tosa.const"() {values = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  %sub = tosa.sub %arg0, %zero : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: return %arg0
+  return %sub : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_sub_zero_rhs_i32
+func.func @fold_sub_zero_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %zero = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %sub = tosa.sub %arg0, %zero : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  // CHECK: return %arg0
+  return %sub : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_sub_splat_i32
+func.func @fold_sub_splat_i32() -> tensor<10xi32> {
+  %one = "tosa.const"() {values = dense<1> : tensor<10xi32>} : () -> tensor<10xi32>
+  %two = "tosa.const"() {values = dense<2> : tensor<10xi32>} : () -> tensor<10xi32>
+  %sub = tosa.sub %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
+  // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<-1> : tensor<10xi32>}
+  // CHECK: return %[[THREE]]
+  return %sub : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_sub_splat_f32
+func.func @fold_sub_splat_f32() -> tensor<10xf32> {
+  %one = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %two = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %sub = tosa.sub %one, %two : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+  // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<-1.000000e+00> : tensor<10xf32>}
+  // CHECK: return %[[THREE]]
+  return %sub : tensor<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_greater_splat_f32
+func.func @fold_greater_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %1 = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %2 = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %3 = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %true = tosa.greater %0, %1 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  %false = tosa.greater %2, %3 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense<true> : tensor<10xi1>}
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[TRUE]], %[[FALSE]]
+  return %true, %false : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_greater_splat_i32
+func.func @fold_greater_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %1 = "tosa.const"() {values = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
+  %2 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %3 = "tosa.const"() {values = dense<-12> : tensor<10xi32>} : () -> tensor<10xi32>
+  %false = tosa.greater %0, %1 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  %true = tosa.greater %2, %3 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense<false> : tensor<10xi1>}
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense<true> : tensor<10xi1>}
+  // CHECK: return %[[FALSE]], %[[TRUE]]
+  return %false, %true : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_greater_eq_splat_f32
+func.func @fold_greater_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %1 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %2 = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %3 = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %true = tosa.greater_equal %0, %1 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  %false = tosa.greater_equal %2, %3 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense<true> : tensor<10xi1>}
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[TRUE]], %[[FALSE]]
+  return %true, %false : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_greater_eq_splat_i32
+func.func @fold_greater_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %1 = "tosa.const"() {values = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
+  %2 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %3 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %true = tosa.greater_equal %2, %3 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  %false = tosa.greater_equal %0, %1 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense<true> : tensor<10xi1>}
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[TRUE]], %[[FALSE]]
+  return %true, %false : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_eq_splat_f32
+func.func @fold_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %1 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %2 = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %3 = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %true = tosa.equal %0, %1 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  %false = tosa.equal %2, %3 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense<true> : tensor<10xi1>}
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[TRUE]], %[[FALSE]]
+  return %true, %false : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_eq_splat_i32
+func.func @fold_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %1 = "tosa.const"() {values = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
+  %2 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %3 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %true = tosa.equal %2, %3 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  %false = tosa.equal %0, %1 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense<true> : tensor<10xi1>}
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[TRUE]], %[[FALSE]]
+  return %true, %false : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_eq_i32
+func.func @fold_eq_i32(%arg0 : tensor<10xi32>) -> (tensor<10xi1>) {
+  // CHECK: %[[TRUE:.+]] = "tosa.const"() <{values = dense<true> : tensor<10xi1>}
+  %0 = tosa.equal %arg0, %arg0 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  // CHECK: return %[[TRUE]]
+  return %0 : tensor<10xi1>
+}
+
+// -----
+
+func.func @reshape_splat() -> tensor<6x5x4xi32> {
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<42> : tensor<6x5x4xi32>}
+  %splat = "tosa.const"() {values = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32>
+  %const = tosa.const_shape {values = dense<[6, 5, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %reshape = tosa.reshape %splat, %const : (tensor<4x5x6xi32>, !tosa.shape<3>) -> tensor<6x5x4xi32>
+  // CHECK: return %[[SPLAT]]
+  return %reshape : tensor<6x5x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_splat
+func.func @slice_splat() -> tensor<1x1x1xi32> {
+  // CHECK: %[[SLICE:.+]] = "tosa.const"() <{values = dense<42> : tensor<1x1x1xi32>}
+  %splat = "tosa.const"() {values = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32>
+  %start = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %size = tosa.const_shape {values = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %slice= tosa.slice %splat, %start, %size : (tensor<4x5x6xi32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x1x1xi32>
+
+  // CHECK: return %[[SLICE]]
+  return %slice : tensor<1x1x1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_singleton
+func.func @slice_singleton() -> tensor<1x1xi32> {
+  %splat = "tosa.const"() {values = dense<[[0, 1, 2], [3, 4, 5], [6, 7 ,8]]> : tensor<3x3xi32>} : () -> tensor<3x3xi32>
+  // CHECK: %[[SLICE:.+]] = "tosa.const"() <{values = dense<4> : tensor<1x1xi32>}
+  %start = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %size = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %slice= tosa.slice %splat, %start, %size : (tensor<3x3xi32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x1xi32>
+  // CHECK: return %[[SLICE]]
+  return %slice : tensor<1x1xi32>
+}
+
+// -----
+
+// CHECK: func.func @cast_float_to_float
+func.func @cast_float_to_float() -> tensor<f16> {
+  %splat = "tosa.const"() {values = dense<42.0> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<4.200000e+01> : tensor<f16>}
+  %cast = tosa.cast %splat : (tensor<f32>) -> tensor<f16>
+  // CHECK: return %[[SPLAT]]
+  return %cast : tensor<f16>
+}
+
+// -----
+
+// CHECK: func.func @cast_int_to_float
+func.func @cast_int_to_float() -> tensor<f16> {
+  %splat = "tosa.const"() {values = dense<4> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<4.000000e+00> : tensor<f16>}
+  %cast = tosa.cast %splat : (tensor<i32>) -> tensor<f16>
+  // CHECK: return %[[SPLAT]]
+  return %cast : tensor<f16>
+}
+
+// -----
+
+// CHECK: func.func @cast_float_to_int
+func.func @cast_float_to_int() -> tensor<i16> {
+  %splat = "tosa.const"() {values = dense<-4.0> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-4> : tensor<i16>}
+  %cast = tosa.cast %splat : (tensor<f32>) -> tensor<i16>
+  // CHECK: return %[[SPLAT]]
+  return %cast : tensor<i16>
+}
+
+// -----
+
+// CHECK: func.func @cast_float_to_int_round
+func.func @cast_float_to_int_round() -> tensor<i16> {
+  %splat = "tosa.const"() {values = dense<-3.5> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-4> : tensor<i16>}
+  %cast = tosa.cast %splat : (tensor<f32>) -> tensor<i16>
+  // CHECK: return %[[SPLAT]]
+  return %cast : tensor<i16>
+}
+
+// -----
+
+// CHECK: func.func @cast_int_to_int_trunc
+func.func @cast_int_to_int_trunc() -> tensor<i16> {
+  %splat = "tosa.const"() {values = dense<-1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-1> : tensor<i16>}
+  %cast = tosa.cast %splat : (tensor<i32>) -> tensor<i16>
+  // CHECK: return %[[SPLAT]]
+  return %cast : tensor<i16>
+}
+
+// -----
+
+// CHECK: func.func @cast_int_to_int_sign
+func.func @cast_int_to_int_sign() -> tensor<i32> {
+  %splat = "tosa.const"() {values = dense<-1> : tensor<i16>} : () -> tensor<i16>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-1> : tensor<i32>}
+  %cast = tosa.cast %splat : (tensor<i16>) -> tensor<i32>
+  // CHECK: return %[[SPLAT]]
+  return %cast : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @reverse_splat
+func.func @reverse_splat() -> tensor<10xi32> {
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<42> : tensor<10xi32>}
+  %splat = "tosa.const"() {values = dense<42> : tensor<10xi32>} : () -> tensor<10xi32>
+  %reverse = tosa.reverse %splat { axis = 0 : i32 } : (tensor<10xi32>) -> tensor<10xi32>
+  // CHECK: return %[[SPLAT]]
+  return %reverse : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @reverse_length_one
+func.func @reverse_length_one(%arg0 : tensor<10x1xi32>) -> (tensor<10x1xi32>, tensor<10x1xi32>) {
+  %nofold = tosa.reverse %arg0 { axis = 0 : i32 } : (tensor<10x1xi32>) -> tensor<10x1xi32>
+  %fold = tosa.reverse %arg0 { axis = 1 : i32 } : (tensor<10x1xi32>) -> tensor<10x1xi32>
+  // CHECK: %[[NOFOLD:.+]] = tosa.reverse %arg0 {axis = 0 : i32}
+  // CHECK: return %[[NOFOLD]], %arg0
+  return %nofold, %fold : tensor<10x1xi32>, tensor<10x1xi32>
+}
+
+// -----
+
+// no_shift_op_reorder checks that %arg1 won't be reordered with %0
+// by the folder pass.
+// CHECK-LABEL: @no_shift_op_reorder
+func.func @no_shift_op_reorder (%arg0 : tensor<44x1xi16>, %arg1 : tensor<1xi8>) -> tensor<44x57xi32> {
+  %0 = "tosa.const"() {values = dense<1> : tensor<44x57xi16>} : () -> tensor<44x57xi16>
+  // CHECK: tosa.mul %arg0, %0, %arg1
+  %1 = tosa.mul %arg0, %0, %arg1 : (tensor<44x1xi16>, tensor<44x57xi16>, tensor<1xi8>) -> tensor<44x57xi32>
+  return %1 : tensor<44x57xi32>
+}
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/tosa-layerwise-constant-fold.mlir
similarity index 63%
rename from mlir/test/Dialect/Tosa/constant-op-fold.mlir
rename to mlir/test/Dialect/Tosa/tosa-layerwise-constant-fold.mlir
index b1fbcdcc53e2f..d95d267e8c907 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-layerwise-constant-fold.mlir
@@ -1,6 +1,4 @@
 // RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s
-
-
 // RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold="aggressive-reduce-constant=true" %s | FileCheck %s --check-prefix=AGGRESIVE
 
 // CHECK-LABEL: @armax_fold_dim_size_1
@@ -10,6 +8,8 @@ func.func @armax_fold_dim_size_1(%arg0: tensor<2x1x3xf32>) -> tensor<2x3xi32> {
   return %0 : tensor<2x3xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @argmax_dynamic_shape_no_fold_dim_size_1
 func.func @argmax_dynamic_shape_no_fold_dim_size_1(%arg0: tensor<?x1x3xf32>) -> tensor<?x3xi32> {
   // CHECK: tosa.argmax
@@ -17,6 +17,8 @@ func.func @argmax_dynamic_shape_no_fold_dim_size_1(%arg0: tensor<?x1x3xf32>) ->
   return %0 : tensor<?x3xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @transpose_fold
 func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   // CHECK: return %arg0
@@ -24,6 +26,8 @@ func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   return %1 : tensor<3x4xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @transpose_nofold
 func.func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
   // CHECK: tosa.transpose
@@ -31,6 +35,8 @@ func.func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
   return %1 : tensor<3x3xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @transpose_nofold_shape
 func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
   // CHECK: tosa.transpose
@@ -38,6 +44,8 @@ func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
   return %1 : tensor<?x?xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @transpose_fold_splat
 func.func @transpose_fold_splat() -> tensor<3x2xf32> {
   %input = "tosa.const"() {values = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
@@ -48,6 +56,8 @@ func.func @transpose_fold_splat() -> tensor<3x2xf32> {
   return %1 : tensor<3x2xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @transpose_fold_2d_float
 func.func @transpose_fold_2d_float() -> tensor<3x2xf32> {
   %input = "tosa.const"() {values = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
@@ -58,6 +68,8 @@ func.func @transpose_fold_2d_float() -> tensor<3x2xf32> {
   return %1 : tensor<3x2xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @transpose_fold_2d_bool
 func.func @transpose_fold_2d_bool() -> tensor<3x2xi1> {
   %input = "tosa.const"() {values = dense<[[true, false, false], [false, false, true]]> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
@@ -68,6 +80,8 @@ func.func @transpose_fold_2d_bool() -> tensor<3x2xi1> {
   return %1 : tensor<3x2xi1>
 }
 
+// -----
+
 // CHECK-LABEL: @transpose_fold_4d_int
 func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
   %input = "tosa.const"() {values = dense<[[
@@ -85,6 +99,8 @@ func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
   return %1 : tensor<3x1x4x2xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @transpose_nofold_non_cst_input
 func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2xf32> {
   // CHECK: tosa.transpose
@@ -92,6 +108,8 @@ func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2
   return %1 : tensor<3x2xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @transpose_nofold_multi_users
 func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) {
   %input = "tosa.const"() {values = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
@@ -100,6 +118,8 @@ func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>)
   return %1, %input : tensor<3x2xf32>, tensor<2x3xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @transpose_nofold_quantized_types
 func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>> {
   %input = "tosa.const"() {values = dense<-127> : tensor<2x1x1x2xi8>} : () -> tensor<2x1x1x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
@@ -108,6 +128,8 @@ func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i
   return %0: tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
 }
 
+// -----
+
 // CHECK-LABEL: @transpose_fold_dense_resource
 func.func @transpose_fold_dense_resource() -> tensor<2x2xf32> {
   %0 = "tosa.const"() <{values = dense_resource<resource> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
@@ -124,498 +146,6 @@ func.func @transpose_fold_dense_resource() -> tensor<2x2xf32> {
   }
 #-}
 
-// -----
-
-// CHECK-LABEL: @fold_add_zero_rhs_f32
-func.func @fold_add_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
-  %zero = "tosa.const"() {values = dense<0.0> : tensor<f32>} : () -> tensor<f32>
-  %add = tosa.add %arg0, %zero : (tensor<f32>, tensor<f32>) -> tensor<f32>
-  // CHECK: return %arg0
-  return %add : tensor<f32>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_add_zero_lhs_f32
-func.func @fold_add_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
-  %zero = "tosa.const"() {values = dense<0.0> : tensor<f32>} : () -> tensor<f32>
-  %add = tosa.add %zero, %arg0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
-  // CHECK: return %arg0
-  return %add : tensor<f32>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_add_zero_rhs_i32
-func.func @fold_add_zero_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
-  %zero = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
-  %add = tosa.add %arg0, %zero : (tensor<i32>, tensor<i32>) -> tensor<i32>
-  // CHECK: return %arg0
-  return %add : tensor<i32>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_add_zero_lhs_i32
-func.func @fold_add_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
-  %zero = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
-  %add = tosa.add %zero, %arg0 : (tensor<i32>, tensor<i32>) -> tensor<i32>
-  // CHECK: return %arg0
-  return %add : tensor<i32>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_add_splat_i32
-func.func @fold_add_splat_i32() -> tensor<10xi32> {
-  %one = "tosa.const"() {values = dense<1> : tensor<10xi32>} : () -> tensor<10xi32>
-  %two = "tosa.const"() {values = dense<2> : tensor<10xi32>} : () -> tensor<10xi32>
-  %add = tosa.add %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
-  // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<3> : tensor<10xi32>}
-  // CHECK: return %[[THREE]]
-  return %add : tensor<10xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_add_splat_f32
-func.func @fold_add_splat_f32() -> tensor<10xf32> {
-  %one = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %two = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %add = tosa.add %one, %two : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
-  // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<3.000000e+00>
-  // CHECK: return %[[THREE]]
-  return %add : tensor<10xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_div_zero_lhs_i32
-func.func @fold_div_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
-  %zero = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
-  // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0>
-  %div = tosa.intdiv %zero, %arg0 : (tensor<i32>, tensor<i32>) -> tensor<i32>
-  // CHECK: return %[[ZERO]]
-  return %div : tensor<i32>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_div_one_rhs_i32
-func.func @fold_div_one_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
-  %one = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
-  %div = tosa.intdiv %arg0, %one : (tensor<i32>, tensor<i32>) -> tensor<i32>
-  // CHECK: return %arg0
-  return %div : tensor<i32>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_div_splat_i32
-func.func @fold_div_splat_i32() -> tensor<i32> {
-  %lhs = "tosa.const"() {values = dense<10> : tensor<i32>} : () -> tensor<i32>
-  %rhs = "tosa.const"() {values = dense<-3> : tensor<i32>} : () -> tensor<i32>
-  // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-3>
-  %div = tosa.intdiv %lhs, %rhs : (tensor<i32>, tensor<i32>) -> tensor<i32>
-  // CHECK: return %[[SPLAT]]
-  return %div : tensor<i32>
-}
-
-// -----
-
-
-// CHECK-LABEL: @fold_mul_zero_rhs_f32
-func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
-  %zero = "tosa.const"() {values = dense<0.0> : tensor<f32>} : () -> tensor<f32>
-  // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0.000000e+00>
-  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
-  %mul = tosa.mul %arg0, %zero, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
-  // CHECK: return %[[ZERO]]
-  return %mul : tensor<f32>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_mul_zero_lhs_f32
-func.func @fold_mul_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
-  %zero = "tosa.const"() {values = dense<0.0> : tensor<f32>} : () -> tensor<f32>
-  // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0.000000e+00>
-  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
-  %mul = tosa.mul %zero, %arg0, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
-  // CHECK: return %[[ZERO]]
-  return %mul : tensor<f32>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_mul_zero_rhs_i32
-func.func @fold_mul_zero_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
-  %zero = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
-  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
-  // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0>
-  %mul = tosa.mul %arg0, %zero, %shift : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
-  // CHECK: return %[[ZERO]]
-  return %mul : tensor<i32>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_mul_zero_lhs_i32
-func.func @fold_mul_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
-  %zero = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
-  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
-  // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0>
-  %mul = tosa.mul %zero, %arg0, %shift : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
-  // CHECK: return %[[ZERO]]
-  return %mul : tensor<i32>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_mul_one_rhs_f32
-func.func @fold_mul_one_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
-  %one = "tosa.const"() {values = dense<1.0> : tensor<f32>} : () -> tensor<f32>
-  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
-  %mul = tosa.mul %arg0, %one, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
-  // CHECK: return %arg0
-  return %mul : tensor<f32>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_mul_one_lhs_f32
-func.func @fold_mul_one_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
-  %one = "tosa.const"() {values = dense<1.0> : tensor<f32>} : () -> tensor<f32>
-  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
-  %mul = tosa.mul %one, %arg0, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
-  // CHECK: return %arg0
-  return %mul : tensor<f32>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_mul_one_rhs_i32
-func.func @fold_mul_one_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
-  %one = "tosa.const"() {values = dense<64> : tensor<i32>} : () -> tensor<i32>
-  %shift = "tosa.const"() {values = dense<6> : tensor<1xi8>} : () -> tensor<1xi8>
-  %mul = tosa.mul %arg0, %one, %shift : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
-  // CHECK: return %arg0
-  return %mul : tensor<i32>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_mul_one_lhs_i32
-func.func @fold_mul_one_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
-  %one = "tosa.const"() {values = dense<64> : tensor<i32>} : () -> tensor<i32>
-  %shift = "tosa.const"() {values = dense<6> : tensor<1xi8>} : () -> tensor<1xi8>
-  %mul = tosa.mul %one, %arg0, %shift : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
-  // CHECK: return %arg0
-  return %mul : tensor<i32>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_mul_splat_i8
-func.func @fold_mul_splat_i8() -> tensor<10xi32> {
-  %one = "tosa.const"() {values = dense<17> : tensor<10xi8>} : () -> tensor<10xi8>
-  %two = "tosa.const"() {values = dense<32> : tensor<10xi8>} : () -> tensor<10xi8>
-  %shift = "tosa.const"() {values = dense<3> : tensor<1xi8>} : () -> tensor<1xi8>
-  %mul = tosa.mul %one, %two, %shift : (tensor<10xi8>, tensor<10xi8>, tensor<1xi8>) -> tensor<10xi32>
-  // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<68> : tensor<10xi32>}
-  // CHECK: return %[[THREE]]
-  return %mul : tensor<10xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_mul_splat_f32
-func.func @fold_mul_splat_f32() -> tensor<10xf32> {
-  %one = "tosa.const"() {values = dense<3.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %two = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
-  %mul = tosa.mul %one, %two, %shift : (tensor<10xf32>, tensor<10xf32>, tensor<1xi8>) -> tensor<10xf32>
-  // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<6.000000e+00> : tensor<10xf32>}
-  // CHECK: return %[[THREE]]
-  return %mul : tensor<10xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_sub_zero_rhs_f32
-func.func @fold_sub_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
-  %zero = "tosa.const"() {values = dense<0.0> : tensor<f32>} : () -> tensor<f32>
-  %sub = tosa.sub %arg0, %zero : (tensor<f32>, tensor<f32>) -> tensor<f32>
-  // CHECK: return %arg0
-  return %sub : tensor<f32>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_sub_zero_rhs_i32
-func.func @fold_sub_zero_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
-  %zero = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
-  %sub = tosa.sub %arg0, %zero : (tensor<i32>, tensor<i32>) -> tensor<i32>
-  // CHECK: return %arg0
-  return %sub : tensor<i32>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_sub_splat_i32
-func.func @fold_sub_splat_i32() -> tensor<10xi32> {
-  %one = "tosa.const"() {values = dense<1> : tensor<10xi32>} : () -> tensor<10xi32>
-  %two = "tosa.const"() {values = dense<2> : tensor<10xi32>} : () -> tensor<10xi32>
-  %sub = tosa.sub %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
-  // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<-1> : tensor<10xi32>}
-  // CHECK: return %[[THREE]]
-  return %sub : tensor<10xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_sub_splat_f32
-func.func @fold_sub_splat_f32() -> tensor<10xf32> {
-  %one = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %two = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %sub = tosa.sub %one, %two : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
-  // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<-1.000000e+00> : tensor<10xf32>}
-  // CHECK: return %[[THREE]]
-  return %sub : tensor<10xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_greater_splat_f32
-func.func @fold_greater_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
-  %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %1 = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %2 = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %3 = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %true = tosa.greater %0, %1 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
-  %false = tosa.greater %2, %3 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
-  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense<true> : tensor<10xi1>}
-  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense<false> : tensor<10xi1>}
-  // CHECK: return %[[TRUE]], %[[FALSE]]
-  return %true, %false : tensor<10xi1>, tensor<10xi1>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_greater_splat_i32
-func.func @fold_greater_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
-  %0 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
-  %1 = "tosa.const"() {values = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
-  %2 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
-  %3 = "tosa.const"() {values = dense<-12> : tensor<10xi32>} : () -> tensor<10xi32>
-  %false = tosa.greater %0, %1 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
-  %true = tosa.greater %2, %3 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
-  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense<false> : tensor<10xi1>}
-  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense<true> : tensor<10xi1>}
-  // CHECK: return %[[FALSE]], %[[TRUE]]
-  return %false, %true : tensor<10xi1>, tensor<10xi1>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_greater_eq_splat_f32
-func.func @fold_greater_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
-  %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %1 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %2 = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %3 = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %true = tosa.greater_equal %0, %1 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
-  %false = tosa.greater_equal %2, %3 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
-  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense<true> : tensor<10xi1>}
-  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense<false> : tensor<10xi1>}
-  // CHECK: return %[[TRUE]], %[[FALSE]]
-  return %true, %false : tensor<10xi1>, tensor<10xi1>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_greater_eq_splat_i32
-func.func @fold_greater_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
-  %0 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
-  %1 = "tosa.const"() {values = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
-  %2 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
-  %3 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
-  %true = tosa.greater_equal %2, %3 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
-  %false = tosa.greater_equal %0, %1 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
-  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense<true> : tensor<10xi1>}
-  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense<false> : tensor<10xi1>}
-  // CHECK: return %[[TRUE]], %[[FALSE]]
-  return %true, %false : tensor<10xi1>, tensor<10xi1>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_eq_splat_f32
-func.func @fold_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
-  %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %1 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %2 = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %3 = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %true = tosa.equal %0, %1 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
-  %false = tosa.equal %2, %3 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
-  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense<true> : tensor<10xi1>}
-  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense<false> : tensor<10xi1>}
-  // CHECK: return %[[TRUE]], %[[FALSE]]
-  return %true, %false : tensor<10xi1>, tensor<10xi1>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_eq_splat_i32
-func.func @fold_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
-  %0 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
-  %1 = "tosa.const"() {values = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
-  %2 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
-  %3 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
-  %true = tosa.equal %2, %3 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
-  %false = tosa.equal %0, %1 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
-  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense<true> : tensor<10xi1>}
-  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense<false> : tensor<10xi1>}
-  // CHECK: return %[[TRUE]], %[[FALSE]]
-  return %true, %false : tensor<10xi1>, tensor<10xi1>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_eq_i32
-func.func @fold_eq_i32(%arg0 : tensor<10xi32>) -> (tensor<10xi1>) {
-  // CHECK: %[[TRUE:.+]] = "tosa.const"() <{values = dense<true> : tensor<10xi1>}
-  %0 = tosa.equal %arg0, %arg0 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
-  // CHECK: return %[[TRUE]]
-  return %0 : tensor<10xi1>
-}
-
-// -----
-
-func.func @reshape_splat() -> tensor<6x5x4xi32> {
-  // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<42> : tensor<6x5x4xi32>}
-  %splat = "tosa.const"() {values = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32>
-  %const = tosa.const_shape {values = dense<[6, 5, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
-  %reshape = tosa.reshape %splat, %const : (tensor<4x5x6xi32>, !tosa.shape<3>) -> tensor<6x5x4xi32>
-  // CHECK: return %[[SPLAT]]
-  return %reshape : tensor<6x5x4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @slice_splat
-func.func @slice_splat() -> tensor<1x1x1xi32> {
-  // CHECK: %[[SLICE:.+]] = "tosa.const"() <{values = dense<42> : tensor<1x1x1xi32>}
-  %splat = "tosa.const"() {values = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32>
-  %start = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
-  %size = tosa.const_shape {values = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
-  %slice= tosa.slice %splat, %start, %size : (tensor<4x5x6xi32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x1x1xi32>
-
-  // CHECK: return %[[SLICE]]
-  return %slice : tensor<1x1x1xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @slice_singleton
-func.func @slice_singleton() -> tensor<1x1xi32> {
-  %splat = "tosa.const"() {values = dense<[[0, 1, 2], [3, 4, 5], [6, 7 ,8]]> : tensor<3x3xi32>} : () -> tensor<3x3xi32>
-  // CHECK: %[[SLICE:.+]] = "tosa.const"() <{values = dense<4> : tensor<1x1xi32>}
-  %start = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
-  %size = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
-  %slice= tosa.slice %splat, %start, %size : (tensor<3x3xi32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x1xi32>
-  // CHECK: return %[[SLICE]]
-  return %slice : tensor<1x1xi32>
-}
-
-// -----
-
-// CHECK: func.func @cast_float_to_float
-func.func @cast_float_to_float() -> tensor<f16> {
-  %splat = "tosa.const"() {values = dense<42.0> : tensor<f32>} : () -> tensor<f32>
-  // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<4.200000e+01> : tensor<f16>}
-  %cast = tosa.cast %splat : (tensor<f32>) -> tensor<f16>
-  // CHECK: return %[[SPLAT]]
-  return %cast : tensor<f16>
-}
-
-// -----
-
-// CHECK: func.func @cast_int_to_float
-func.func @cast_int_to_float() -> tensor<f16> {
-  %splat = "tosa.const"() {values = dense<4> : tensor<i32>} : () -> tensor<i32>
-  // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<4.000000e+00> : tensor<f16>}
-  %cast = tosa.cast %splat : (tensor<i32>) -> tensor<f16>
-  // CHECK: return %[[SPLAT]]
-  return %cast : tensor<f16>
-}
-
-// -----
-
-// CHECK: func.func @cast_float_to_int
-func.func @cast_float_to_int() -> tensor<i16> {
-  %splat = "tosa.const"() {values = dense<-4.0> : tensor<f32>} : () -> tensor<f32>
-  // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-4> : tensor<i16>}
-  %cast = tosa.cast %splat : (tensor<f32>) -> tensor<i16>
-  // CHECK: return %[[SPLAT]]
-  return %cast : tensor<i16>
-}
-
-// -----
-
-// CHECK: func.func @cast_float_to_int_round
-func.func @cast_float_to_int_round() -> tensor<i16> {
-  %splat = "tosa.const"() {values = dense<-3.5> : tensor<f32>} : () -> tensor<f32>
-  // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-4> : tensor<i16>}
-  %cast = tosa.cast %splat : (tensor<f32>) -> tensor<i16>
-  // CHECK: return %[[SPLAT]]
-  return %cast : tensor<i16>
-}
-
-// -----
-
-// CHECK: func.func @cast_int_to_int_trunc
-func.func @cast_int_to_int_trunc() -> tensor<i16> {
-  %splat = "tosa.const"() {values = dense<-1> : tensor<i32>} : () -> tensor<i32>
-  // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-1> : tensor<i16>}
-  %cast = tosa.cast %splat : (tensor<i32>) -> tensor<i16>
-  // CHECK: return %[[SPLAT]]
-  return %cast : tensor<i16>
-}
-
-// -----
-
-// CHECK: func.func @cast_int_to_int_sign
-func.func @cast_int_to_int_sign() -> tensor<i32> {
-  %splat = "tosa.const"() {values = dense<-1> : tensor<i16>} : () -> tensor<i16>
-  // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-1> : tensor<i32>}
-  %cast = tosa.cast %splat : (tensor<i16>) -> tensor<i32>
-  // CHECK: return %[[SPLAT]]
-  return %cast : tensor<i32>
-}
-
-// -----
-
-// CHECK-LABEL: @reverse_splat
-func.func @reverse_splat() -> tensor<10xi32> {
-  // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<42> : tensor<10xi32>}
-  %splat = "tosa.const"() {values = dense<42> : tensor<10xi32>} : () -> tensor<10xi32>
-  %reverse = tosa.reverse %splat { axis = 0 : i32 } : (tensor<10xi32>) -> tensor<10xi32>
-  // CHECK: return %[[SPLAT]]
-  return %reverse : tensor<10xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @reverse_length_one
-func.func @reverse_length_one(%arg0 : tensor<10x1xi32>) -> (tensor<10x1xi32>, tensor<10x1xi32>) {
-  %nofold = tosa.reverse %arg0 { axis = 0 : i32 } : (tensor<10x1xi32>) -> tensor<10x1xi32>
-  %fold = tosa.reverse %arg0 { axis = 1 : i32 } : (tensor<10x1xi32>) -> tensor<10x1xi32>
-  // CHECK: %[[NOFOLD:.+]] = tosa.reverse %arg0 {axis = 0 : i32}
-  // CHECK: return %[[NOFOLD]], %arg0
-  return %nofold, %fold : tensor<10x1xi32>, tensor<10x1xi32>
-}
-
 // -----
 
   func.func @reduce_sum_constant() -> tensor<1x3xi32> {
@@ -1172,15 +702,3 @@ func.func @reduce_sum_constant_aggressive() -> tensor<2x3xi32> {
   %res1 = tosa.add %res0, %argmax1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   return %res1 : tensor<2x3xi32>
 }
-
-// -----
-
-// no_shift_op_reorder checks that %arg1 won't be reorder with %0
-// by the folder pass.
-// CHECK-LABEL: @no_shift_op_reorder
-func.func @no_shift_op_reorder (%arg0 : tensor<44x1xi16>, %arg1 : tensor<1xi8>) -> tensor<44x57xi32> {
-  %0 = "tosa.const"() {values = dense<1> : tensor<44x57xi16>} : () -> tensor<44x57xi16>
-  // CHECK: tosa.mul %arg0, %0, %arg1
-  %1 = tosa.mul %arg0, %0, %arg1 : (tensor<44x1xi16>, tensor<44x57xi16>, tensor<1xi8>) -> tensor<44x57xi32>
-  return %1 : tensor<44x57xi32>
-}

>From cc4a8d82434a881cc32eb4d8da1306c5a2b7dcd1 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 17 Dec 2025 13:28:15 +0000
Subject: [PATCH 2/2] [mlir][tosa] Check for overflow in integer folders

For these folders to be TOSA compliant, they need to check
for overflow. This commit adds those checks, subsequently
preventing folding if an overflow is detected.

This commit also fixes the greater/greater_equal folders
to account for unsigned types.

Change-Id: I2b5a5b92fb840d6c34a1f2faa18ae68a20d0ecdf
---
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 177 +++++++++++++-----
 mlir/test/Dialect/Tosa/constant_folding.mlir  | 121 ++++++++++++
 2 files changed, 246 insertions(+), 52 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index c420a4c9596ff..3e9d803a916a9 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -889,33 +889,141 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // Operator Folders.
 //===----------------------------------------------------------------------===//
 
-template <typename IntFolder, typename FloatFolder>
+template <typename Folder>
 static DenseElementsAttr binaryFolder(DenseElementsAttr lhs,
                                       DenseElementsAttr rhs,
                                       RankedTensorType returnTy) {
   if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
-    auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
-    auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
+    const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
+    const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
     if (lETy != rETy)
       return {};
 
-    if (llvm::isa<IntegerType>(lETy)) {
-      APInt l = lhs.getSplatValue<APInt>();
-      APInt r = rhs.getSplatValue<APInt>();
-      auto result = IntFolder()(l, r);
-      return DenseElementsAttr::get(returnTy, result);
+    if (const auto lIntTy = dyn_cast<IntegerType>(lETy)) {
+      const APInt l = lhs.getSplatValue<APInt>();
+      const APInt r = rhs.getSplatValue<APInt>();
+      const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
+      if (failed(maybeResult))
+        return {};
+      return DenseElementsAttr::get(returnTy, maybeResult.value());
     }
 
     if (llvm::isa<FloatType>(lETy)) {
-      APFloat l = lhs.getSplatValue<APFloat>();
-      APFloat r = rhs.getSplatValue<APFloat>();
-      auto result = FloatFolder()(l, r);
-      return DenseElementsAttr::get(returnTy, result);
+      const APFloat l = lhs.getSplatValue<APFloat>();
+      const APFloat r = rhs.getSplatValue<APFloat>();
+      const auto maybeResult = Folder::fold(l, r);
+      if (failed(maybeResult))
+        return {};
+      return DenseElementsAttr::get(returnTy, maybeResult.value());
     }
   }
 
   return {};
 }
+struct AddFoldAdaptor {
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    const unsigned originalWidth = lhs.getBitWidth();
+
+    APInt lhs64, rhs64;
+    if (isUnsigned) {
+      lhs64 = lhs.zext(64);
+      rhs64 = rhs.zext(64);
+
+      // Check for overflow: lhs + rhs > max, rewritten as lhs > max - rhs
+      const APInt max = APInt::getMaxValue(originalWidth).zext(64);
+      if (lhs64.ugt(max - rhs64))
+        return failure();
+    } else {
+      lhs64 = lhs.sext(64);
+      rhs64 = rhs.sext(64);
+
+      // Check for overflow past max (rhs > 0) or below min (rhs < 0)
+      const APInt zero = APInt::getZero(64);
+      const APInt max = APInt::getSignedMaxValue(originalWidth).sext(64);
+      const APInt min = APInt::getSignedMinValue(originalWidth).sext(64);
+      if ((rhs64.sgt(zero) && lhs64.sgt(max - rhs64)) ||
+          (rhs64.slt(zero) && lhs64.slt(min - rhs64)))
+        return failure();
+    }
+
+    const APInt result64 = lhs64 + rhs64;
+    return result64.trunc(originalWidth);
+  }
+
+  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+    return lhs + rhs;
+  }
+};
+
+struct SubFoldAdaptor {
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    const unsigned originalWidth = lhs.getBitWidth();
+
+    APInt lhs64, rhs64;
+    if (isUnsigned) {
+      lhs64 = lhs.zext(64);
+      rhs64 = rhs.zext(64);
+
+      // Check for overflow: an unsigned difference is only representable
+      // when lhs >= rhs, so no upper bound is needed here.
+      if (lhs64.ult(rhs64))
+        return failure();
+    } else {
+      lhs64 = lhs.sext(64);
+      rhs64 = rhs.sext(64);
+
+      // Check for overflow below min (rhs > 0) or past max (rhs < 0)
+      const APInt zero = APInt::getZero(64);
+      const APInt max = APInt::getSignedMaxValue(originalWidth).sext(64);
+      const APInt min = APInt::getSignedMinValue(originalWidth).sext(64);
+      if ((rhs64.sgt(zero) && lhs64.slt(min + rhs64)) ||
+          (rhs64.slt(zero) && lhs64.sgt(max + rhs64)))
+        return failure();
+    }
+
+    const APInt result64 = lhs64 - rhs64;
+    return result64.trunc(originalWidth);
+  }
+
+  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+    return lhs - rhs;
+  }
+};
+
+struct FoldGreaterAdaptor {
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    return isUnsigned ? APInt(1, lhs.ugt(rhs)) : APInt(1, lhs.sgt(rhs));
+  }
+
+  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
+    return APInt(1, lhs > rhs);
+  }
+};
+
+struct FoldGreaterEqualAdaptor {
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    return isUnsigned ? APInt(1, lhs.uge(rhs)) : APInt(1, lhs.sge(rhs));
+  }
+
+  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
+    return APInt(1, lhs >= rhs);
+  }
+};
+
+struct FoldEqualAdaptor {
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    return APInt(1, lhs == rhs);
+  }
+
+  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
+    return APInt(1, lhs == rhs);
+  }
+};
 
 static bool isSplatZero(Type elemType, DenseElementsAttr val) {
   if (llvm::isa<FloatType>(elemType))
@@ -963,8 +1071,7 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
-                                                            resultTy);
+  return binaryFolder<AddFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
 }
 
 OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
@@ -1145,38 +1252,9 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
-                                                              resultTy);
+  return binaryFolder<SubFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
 }
 
-namespace {
-template <typename Cmp>
-struct ComparisonFold {
-  ComparisonFold() = default;
-  APInt operator()(const APInt &l, const APInt &r) {
-    return APInt(1, Cmp()(l, r));
-  }
-
-  APInt operator()(const APFloat &l, const APFloat &r) {
-    return APInt(1, Cmp()(l, r));
-  }
-};
-
-struct APIntFoldGreater {
-  APIntFoldGreater() = default;
-  APInt operator()(const APInt &l, const APInt &r) {
-    return APInt(1, l.sgt(r));
-  }
-};
-
-struct APIntFoldGreaterEqual {
-  APIntFoldGreaterEqual() = default;
-  APInt operator()(const APInt &l, const APInt &r) {
-    return APInt(1, l.sge(r));
-  }
-};
-} // namespace
-
 OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
   auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
   auto lhsAttr =
@@ -1187,8 +1265,7 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
-      lhsAttr, rhsAttr, resultTy);
+  return binaryFolder<FoldGreaterAdaptor>(lhsAttr, rhsAttr, resultTy);
 }
 
 OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
@@ -1201,9 +1278,7 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<APIntFoldGreaterEqual,
-                      ComparisonFold<std::greater_equal<APFloat>>>(
-      lhsAttr, rhsAttr, resultTy);
+  return binaryFolder<FoldGreaterEqualAdaptor>(lhsAttr, rhsAttr, resultTy);
 }
 
 OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
@@ -1226,9 +1301,7 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
-                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
-                                                              resultTy);
+  return binaryFolder<FoldEqualAdaptor>(lhsAttr, rhsAttr, resultTy);
 }
 
 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index bf6e1ad23bcb9..0922d6d2ee6fb 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -92,6 +92,50 @@ func.func @fold_add_splat_f32() -> tensor<10xf32> {
 
 // -----
 
+// CHECK-LABEL: @fold_add_splat_i32_positive_overflow
+func.func @fold_add_splat_i32_positive_overflow() -> tensor<10xi32> {
+  %one = "tosa.const"() {values = dense<2147483647> : tensor<10xi32>} : () -> tensor<10xi32>
+  %two = "tosa.const"() {values = dense<1> : tensor<10xi32>} : () -> tensor<10xi32>
+  // CHECK: tosa.add
+  %add = tosa.add %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
+  return %add : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_splat_i32_negative_overflow
+func.func @fold_add_splat_i32_negative_overflow() -> tensor<10xi32> {
+  %one = "tosa.const"() {values = dense<-1> : tensor<10xi32>} : () -> tensor<10xi32>
+  %two = "tosa.const"() {values = dense<-2147483648> : tensor<10xi32>} : () -> tensor<10xi32>
+  // CHECK: tosa.add
+  %add = tosa.add %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
+  return %add : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_splat_ui8
+func.func @fold_add_splat_ui8() -> tensor<10xui8> {
+  %one = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8>
+  %two = "tosa.const"() {values = dense<254> : tensor<10xui8>} : () -> tensor<10xui8>
+  // CHECK: "tosa.const"() <{values = dense<255> : tensor<10xui8>}> : () -> tensor<10xui8>
+  %add = tosa.add %one, %two : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xui8>
+  return %add : tensor<10xui8>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_splat_ui8_overflow
+func.func @fold_add_splat_ui8_overflow() -> tensor<10xui8> {
+  %one = "tosa.const"() {values = dense<2> : tensor<10xui8>} : () -> tensor<10xui8>
+  %two = "tosa.const"() {values = dense<254> : tensor<10xui8>} : () -> tensor<10xui8>
+  // CHECK: tosa.add
+  %add = tosa.add %one, %two : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xui8>
+  return %add : tensor<10xui8>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_div_zero_lhs_i32
 func.func @fold_div_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
   %zero = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
@@ -288,6 +332,50 @@ func.func @fold_sub_splat_f32() -> tensor<10xf32> {
 
 // -----
 
+// CHECK-LABEL: @fold_sub_splat_i32_positive_overflow
+func.func @fold_sub_splat_i32_positive_overflow() -> tensor<10xi32> {
+  %one = "tosa.const"() {values = dense<2147483647> : tensor<10xi32>} : () -> tensor<10xi32>
+  %two = "tosa.const"() {values = dense<-1> : tensor<10xi32>} : () -> tensor<10xi32>
+  // CHECK: tosa.sub
+  %sub = tosa.sub %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
+  return %sub : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_sub_splat_i32_negative_overflow
+func.func @fold_sub_splat_i32_negative_overflow() -> tensor<10xi32> {
+  %one = "tosa.const"() {values = dense<-2147483648> : tensor<10xi32>} : () -> tensor<10xi32>
+  %two = "tosa.const"() {values = dense<1> : tensor<10xi32>} : () -> tensor<10xi32>
+  // CHECK: tosa.sub
+  %sub = tosa.sub %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
+  return %sub : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_sub_splat_ui8
+func.func @fold_sub_splat_ui8() -> tensor<10xui8> {
+  %one = "tosa.const"() {values = dense<255> : tensor<10xui8>} : () -> tensor<10xui8>
+  %two = "tosa.const"() {values = dense<253> : tensor<10xui8>} : () -> tensor<10xui8>
+  // CHECK: "tosa.const"() <{values = dense<2> : tensor<10xui8>}> : () -> tensor<10xui8>
+  %sub = tosa.sub %one, %two : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xui8>
+  return %sub : tensor<10xui8>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_sub_splat_ui8_overflow
+func.func @fold_sub_splat_ui8_overflow() -> tensor<10xui8> {
+  %one = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8>
+  %two = "tosa.const"() {values = dense<253> : tensor<10xui8>} : () -> tensor<10xui8>
+  // CHECK: tosa.sub
+  %sub = tosa.sub %one, %two : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xui8>
+  return %sub : tensor<10xui8>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_greater_splat_f32
 func.func @fold_greater_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
   %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
@@ -320,6 +408,23 @@ func.func @fold_greater_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
 
 // -----
 
+// CHECK-LABEL: @fold_greater_splat_ui8
+func.func @fold_greater_splat_ui8() -> (tensor<10xi1>, tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8>
+  %1 = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8>
+  %2 = "tosa.const"() {values = dense<246> : tensor<10xui8>} : () -> tensor<10xui8>
+  %3 = "tosa.const"() {values = dense<245> : tensor<10xui8>} : () -> tensor<10xui8>
+  %true = tosa.greater %2, %3 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1>
+  %false = tosa.greater %0, %1 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1>
+  %false2 = tosa.greater %0, %2 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1>
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense<true> : tensor<10xi1>}
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[TRUE]], %[[FALSE]], %[[FALSE]]
+  return %true, %false, %false2 : tensor<10xi1>, tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_greater_eq_splat_f32
 func.func @fold_greater_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
   %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
@@ -352,6 +457,22 @@ func.func @fold_greater_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
 
 // -----
 
+// CHECK-LABEL: @fold_greater_eq_splat_ui8
+func.func @fold_greater_eq_splat_ui8() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8>
+  %1 = "tosa.const"() {values = dense<255> : tensor<10xui8>} : () -> tensor<10xui8>
+  %2 = "tosa.const"() {values = dense<245> : tensor<10xui8>} : () -> tensor<10xui8>
+  %3 = "tosa.const"() {values = dense<245> : tensor<10xui8>} : () -> tensor<10xui8>
+  %true = tosa.greater_equal %2, %3 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1>
+  %false = tosa.greater_equal %0, %1 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1>
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense<true> : tensor<10xi1>}
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[TRUE]], %[[FALSE]]
+  return %true, %false : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_eq_splat_f32
 func.func @fold_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
   %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>



More information about the Mlir-commits mailing list