[Mlir-commits] [mlir] [Tosa] Disable tosa folder for non-int/float/index types (PR #71757)

Tai Ly llvmlistbot at llvm.org
Mon Jun 10 13:34:14 PDT 2024


https://github.com/Tai78641 updated https://github.com/llvm/llvm-project/pull/71757

>From a8b0486d137a561bb3a90d6146c5c5f254e6c0da Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Sat, 28 Oct 2023 21:01:09 +0000
Subject: [PATCH] [Tosa] Disable tosa folder for non-int/float/index types

In order to fold, we need to create a DenseElementsAttr, which
does not support quantized element types. This patch adds tests
for folding quantized element types and disables the tosa folders
where appropriate.

Also refactored the canonicalize.mlir test to use --split-input-file.

Also fixed the verifier for the MulOperandsAndResultElementType trait
for quantized element types.

Signed-off-by: Tai Ly <tai.ly at arm.com>
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h   |   4 +-
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp |  18 +-
 mlir/test/Dialect/Tosa/canonicalize.mlir      | 248 +++++++++++++++++-
 3 files changed, 258 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index ec3c2cb011c35..7ed89bff474a2 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -78,7 +78,9 @@ class MulOperandsAndResultElementType
       return success();
     }
 
-    return failure();
+    // In cases of all other types, op requires the same element
+    // type for all operands and result.
+    return impl::verifySameOperandsAndResultElementType(op);
   }
 };
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index e4e5fe3d1db30..8687be075ea67 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -491,6 +491,11 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
   if (!lhsTy || !rhsTy || !resultTy)
     return {};
 
+  // Cannot create an ElementsAttr from non-int/float/index types
+  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
+      !rhsTy.getElementType().isIntOrIndexOrFloat())
+    return {};
+
   auto resultETy = resultTy.getElementType();
   auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
   auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
@@ -529,6 +534,7 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
   if (lhsTy != rhsTy)
     return {};
 
+  // IntDivOp inputs must be integer type, no need to check for quantized type
   auto resultETy = resultTy.getElementType();
   auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
   auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
@@ -626,6 +632,11 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
   if (!lhsTy || !rhsTy || !resultTy)
     return {};
 
+  // Cannot create an ElementsAttr from non-int/float/index types
+  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
+      !rhsTy.getElementType().isIntOrIndexOrFloat())
+    return {};
+
   auto resultETy = resultTy.getElementType();
   auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
   auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
@@ -821,6 +832,10 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
     return getResult();
   }
 
+  // Cannot create an ElementsAttr from non-int/float/index types
+  if (!inputTy.getElementType().isIntOrIndexOrFloat())
+    return {};
+
   // reshape(const(x)) -> const(reshape-attr(x))
   if (auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
     // Constants must have static shape.
@@ -956,13 +971,12 @@ OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
 }
 
 OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
-  auto inputTy = llvm::cast<ShapedType>(getInput1().getType());
   auto resultTy = llvm::cast<ShapedType>(getType());
 
   // Transposing splat values just means reshaping.
   if (auto input = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
     if (input.isSplat() && resultTy.hasStaticShape() &&
-        inputTy.getElementType() == resultTy.getElementType())
+        input.getType().getElementType() == resultTy.getElementType())
       return input.reshape(resultTy);
   }
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 6eac759a08364..c9fc43902eea6 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -canonicalize="test-convergence" %s | FileCheck %s
+// RUN: mlir-opt --split-input-file -canonicalize="test-convergence" %s | FileCheck %s
 
 // CHECK-LABEL: @argmax_nofold
 func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
@@ -7,6 +7,8 @@ func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
   return %0 : tensor<?x1xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @add_bcast_zero_int
 func.func @add_bcast_zero_int(%arg0: tensor<4x2x3xi32>) -> tensor<4x2x3xi32> {
   // CHECK-NOT: tosa.add
@@ -16,6 +18,8 @@ func.func @add_bcast_zero_int(%arg0: tensor<4x2x3xi32>) -> tensor<4x2x3xi32> {
   return %1 : tensor<4x2x3xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @add_zero_int
 func.func @add_zero_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
   // CHECK: return %arg0
@@ -25,6 +29,8 @@ func.func @add_zero_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
   return %1 : tensor<2x3xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @cast_fold
 func.func @cast_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
@@ -32,6 +38,8 @@ func.func @cast_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @cast_nofold
 func.func @cast_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
   // CHECK: tosa.cast
@@ -39,6 +47,8 @@ func.func @cast_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
   return %0 : tensor<?x1xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @clamp_i32_not_noop
 func.func @clamp_i32_not_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
   // CHECK: tosa.clamp
@@ -46,6 +56,8 @@ func.func @clamp_i32_not_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
   return %0 : tensor<4xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @clamp_f16_not_noop
 func.func @clamp_f16_not_noop(%arg0: tensor<4xf16>) -> tensor<4xf16> {
   // CHECK: tosa.clamp
@@ -53,6 +65,8 @@ func.func @clamp_f16_not_noop(%arg0: tensor<4xf16>) -> tensor<4xf16> {
   return %0 : tensor<4xf16>
 }
 
+// -----
+
 // CHECK-LABEL: @clamp_f32_not_noop
 func.func @clamp_f32_not_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
   // CHECK: tosa.clamp
@@ -60,6 +74,8 @@ func.func @clamp_f32_not_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
   return %0 : tensor<4xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @clamp_f16_is_noop
 func.func @clamp_f16_is_noop(%arg0: tensor<4xf16>) -> tensor<4xf16> {
   // CHECK: return %arg0
@@ -69,6 +85,8 @@ func.func @clamp_f16_is_noop(%arg0: tensor<4xf16>) -> tensor<4xf16> {
   return %0 : tensor<4xf16>
 }
 
+// -----
+
 // CHECK-LABEL: @clamp_f32_is_noop
 func.func @clamp_f32_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
   // CHECK: return %arg0
@@ -78,6 +96,8 @@ func.func @clamp_f32_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
   return %0 : tensor<4xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @clamp_int8_is_noop
 func.func @clamp_int8_is_noop(%arg0: tensor<4xi8>) -> tensor<4xi8> {
   // CHECK: return %arg0
@@ -86,6 +106,8 @@ func.func @clamp_int8_is_noop(%arg0: tensor<4xi8>) -> tensor<4xi8> {
   return %0 : tensor<4xi8>
 }
 
+// -----
+
 // CHECK-LABEL: @clamp_int16_is_noop
 func.func @clamp_int16_is_noop(%arg0: tensor<4xi16>) -> tensor<4xi16> {
   // CHECK: return %arg0
@@ -94,6 +116,8 @@ func.func @clamp_int16_is_noop(%arg0: tensor<4xi16>) -> tensor<4xi16> {
   return %0 : tensor<4xi16>
 }
 
+// -----
+
 // CHECK-LABEL: @clamp_uint8_is_noop
 func.func @clamp_uint8_is_noop(%arg0: tensor<4xui8>) -> tensor<4xui8> {
   // CHECK: return %arg0
@@ -102,6 +126,8 @@ func.func @clamp_uint8_is_noop(%arg0: tensor<4xui8>) -> tensor<4xui8> {
   return %0 : tensor<4xui8>
 }
 
+// -----
+
 // CHECK-LABEL: @clamp_twice_is_single_clamp
 func.func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
   // CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64}
@@ -110,6 +136,8 @@ func.func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
   return %1 : tensor<4xi8>
 }
 
+// -----
+
 // CHECK-LABEL: @concat_fold
 func.func @concat_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
@@ -117,6 +145,8 @@ func.func @concat_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @concat_fold_cast
 func.func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
   // CHECK: %[[VAR0:.*]] = tensor.cast %arg0
@@ -125,6 +155,8 @@ func.func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
   return %0 : tensor<?x?xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @conv2d_stride_2
 func.func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32> {
   // CHECK: tosa.conv2d
@@ -134,6 +166,8 @@ func.func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32
   return %0 : tensor<4x10x10x3xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @conv2d_weight_2x2
 func.func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf32> {
   // CHECK: tosa.conv2d
@@ -143,6 +177,8 @@ func.func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf
   return %0 : tensor<4x10x10x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @depthwise_conv2d_stride_2
 func.func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
   // CHECK: tosa.depthwise_conv2d
@@ -150,6 +186,8 @@ func.func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor
   return %0 : tensor<4x10x10x6xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @depthwise_conv2d_weight_2x2
 func.func @depthwise_conv2d_weight_2x2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<2x2x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
   // CHECK: tosa.depthwise_conv2d
@@ -157,6 +195,8 @@ func.func @depthwise_conv2d_weight_2x2(%arg0: tensor<4x10x10x2xf32>, %arg1: tens
   return %0 : tensor<4x10x10x6xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @max_pool2d_is_noop
 func.func @max_pool2d_is_noop(%arg0: tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf32> {
   // CHECK-NOT: tosa.max_pool2d
@@ -165,6 +205,8 @@ func.func @max_pool2d_is_noop(%arg0: tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf3
   return %0 : tensor<10x1x1x3xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @pad_noop
 func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
   // CHECK: return %arg0
@@ -173,6 +215,8 @@ func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
   return %1 : tensor<?x?xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @pad_determine_val_i32
 func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}
@@ -182,6 +226,8 @@ func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>
   return %1 : tensor<?x?xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @pad_determine_val_f32
 func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xf32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
@@ -191,6 +237,8 @@ func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>
   return %1 : tensor<?x?xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @pad_determine_val_quant
 func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<42> : tensor<i32>}
@@ -200,6 +248,8 @@ func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi3
   return %1 : tensor<?x?xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @mul_one_float
 func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   // CHECK: return %arg0
@@ -209,6 +259,8 @@ func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   return %1 : tensor<2x3xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @mul_bcast_one_float
 func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   // CHECK: return %arg0
@@ -218,6 +270,8 @@ func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   return %1 : tensor<2x3xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @mul_one_int
 func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
   // CHECK: return %arg0
@@ -227,6 +281,8 @@ func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
   return %1 : tensor<2x3xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @mul_zero_broadcast
 func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>) {
   // CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3xf32>}
@@ -240,6 +296,8 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso
   return %1, %2 : tensor<2x3xf32>, tensor<2x3xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @select_same_value
 func.func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
   %0 = tosa.select %arg0, %arg1, %arg1 : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
@@ -248,6 +306,8 @@ func.func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> t
   return %0 : tensor<2x3xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @select_true_value
 func.func @select_true_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
   %c1 = "tosa.const"() {value = dense<1> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
@@ -257,6 +317,8 @@ func.func @select_true_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) ->
   return %0 : tensor<2x3xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @select_false_value
 func.func @select_false_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
   %c0 = "tosa.const"() {value = dense<0> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
@@ -266,6 +328,8 @@ func.func @select_false_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) ->
   return %0 : tensor<2x3xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @select_not_pred
 func.func @select_not_pred(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
   %0 = tosa.logical_not %arg0 : (tensor<2x3xi1>) -> tensor<2x3xi1>
@@ -274,6 +338,8 @@ func.func @select_not_pred(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2:
   return %1 : tensor<2x3xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @reduce_all_fold
 func.func @reduce_all_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
@@ -281,6 +347,8 @@ func.func @reduce_all_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reduce_all_nofold
 func.func @reduce_all_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: tosa.reduce_all
@@ -288,6 +356,8 @@ func.func @reduce_all_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reduce_any_fold
 func.func @reduce_any_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
@@ -295,6 +365,8 @@ func.func @reduce_any_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reduce_any_nofold
 func.func @reduce_any_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: tosa.reduce_any
@@ -302,6 +374,8 @@ func.func @reduce_any_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reduce_max_fold
 func.func @reduce_max_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
@@ -309,6 +383,8 @@ func.func @reduce_max_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reduce_max_nofold
 func.func @reduce_max_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: tosa.reduce_max
@@ -316,6 +392,8 @@ func.func @reduce_max_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reduce_min_fold
 func.func @reduce_min_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
@@ -323,6 +401,8 @@ func.func @reduce_min_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reduce_min_nofold
 func.func @reduce_min_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: tosa.reduce_min
@@ -330,6 +410,8 @@ func.func @reduce_min_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reduce_prod_fold
 func.func @reduce_prod_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
@@ -337,6 +419,8 @@ func.func @reduce_prod_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reduce_prod_nofold
 func.func @reduce_prod_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: tosa.reduce_prod
@@ -344,6 +428,8 @@ func.func @reduce_prod_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reduce_sum_fold
 func.func @reduce_sum_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
@@ -351,6 +437,8 @@ func.func @reduce_sum_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reduce_sum_nofold
 func.func @reduce_sum_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: tosa.reduce_sum
@@ -358,6 +446,8 @@ func.func @reduce_sum_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reshape_canonicalize
 func.func @reshape_canonicalize(%arg0: tensor<?x10xf32>) -> tensor<?x10xf32> {
   // CHECK: return %arg0
@@ -365,6 +455,8 @@ func.func @reshape_canonicalize(%arg0: tensor<?x10xf32>) -> tensor<?x10xf32> {
   return %0 : tensor<?x10xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reshape_canonicalize_dyn_nofold
 func.func @reshape_canonicalize_dyn_nofold(%arg0: tensor<?x?x10xf32>) -> tensor<?x?x10xf32> {
   // CHECK: %[[VAR0:.+]] = tosa.reshape %arg0 {new_shape = array<i64: -1, 2, 10>} : (tensor<?x?x10xf32>) -> tensor<?x?x10xf32>
@@ -373,6 +465,8 @@ func.func @reshape_canonicalize_dyn_nofold(%arg0: tensor<?x?x10xf32>) -> tensor<
   return %0 : tensor<?x?x10xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reshape_canonicalize_double
 func.func @reshape_canonicalize_double(%arg0: tensor<?x10xf32>) -> tensor<?x5xf32> {
   // CHECK: %[[VAL_1:.*]] = tosa.reshape %arg0 {new_shape = array<i64: -1, 5>}
@@ -382,6 +476,8 @@ func.func @reshape_canonicalize_double(%arg0: tensor<?x10xf32>) -> tensor<?x5xf3
   return %1 : tensor<?x5xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reshape_canonicalize_const
 func.func @reshape_canonicalize_const() -> tensor<1x5xi32> {
   // CHECK: %[[VAR0:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 1, 2, 3, 4]]> : tensor<1x5xi32>}
@@ -391,6 +487,8 @@ func.func @reshape_canonicalize_const() -> tensor<1x5xi32> {
   return %1 : tensor<1x5xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @reshape_canonicalize_const_dynamic
 func.func @reshape_canonicalize_const_dynamic() -> tensor<1x?xi32> {
   // CHECK: tosa.reshape
@@ -399,6 +497,8 @@ func.func @reshape_canonicalize_const_dynamic() -> tensor<1x?xi32> {
   return %1 : tensor<1x?xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @reshape_canonicalize_const_splat
 func.func @reshape_canonicalize_const_splat() -> (tensor<10xi32>, tensor<1x10xi32>) {
   // CHECK-DAG: %[[VAR0:.+]] = "tosa.const"() <{value = dense<0> : tensor<10xi32>}
@@ -409,6 +509,8 @@ func.func @reshape_canonicalize_const_splat() -> (tensor<10xi32>, tensor<1x10xi3
   return %0 , %1 : tensor<10xi32>, tensor<1x10xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @reshape_canonicalize_const_sparse
 func.func @reshape_canonicalize_const_sparse() -> (tensor<3xi32>, tensor<1x3xi32>) {
   // CHECK: tosa.reshape
@@ -417,23 +519,32 @@ func.func @reshape_canonicalize_const_sparse() -> (tensor<3xi32>, tensor<1x3xi32
   return %0 , %1 : tensor<3xi32>, tensor<1x3xi32>
 }
 
-// CHECK-LABEL: @reshape_canonicalize_quant
-func.func @reshape_canonicalize_quant() -> (tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>>) {
-  // CHECK{LITERAL}: "tosa.const"() <{value = dense<[[1, 2, 3]]> : tensor<1x3xi8>}> : () -> tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>>
+// -----
+
+// CHECK-LABEL: @reshape_canonicalize_quant_nofold
+func.func @reshape_canonicalize_quant_nofold() -> (tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>>) {
+  // disabled folding for quantized element types
+  // CHECK{LITERAL}: "tosa.const"() <{value = dense<[1, 2, 3]> : tensor<3xi8>}> : () -> tensor<3x!quant.uniform<i8:f32, 1.000000e+00>>
+  // CHECK{LITERAL}: tosa.reshape %0 {new_shape = array<i64: 1, 3>} : (tensor<3x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>>
   %0 = "tosa.const"() {value = dense<[1, 2, 3]> : tensor<3xi8>} : ()-> tensor<3x!quant.uniform<i8:f32, 1.000000e+00>>
   %1 = tosa.reshape %0 {new_shape = array<i64: 1, 3>} : (tensor<3x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>>
   return %1 :  tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>>
 }
 
+// -----
+
 // CHECK-LABEL: @transpose_canonicalize_strip_quant
-func.func @transpose_canonicalize_strip_quant() -> (tensor<2x1x3xi8>) {
-  // CHECK: "tosa.const"() <{value = dense<0> : tensor<2x1x3xi8>}> : () -> tensor<2x1x3xi8>
+func.func @transpose_canonicalize_strip_quant() -> (tensor<2x1x3x!quant.uniform<i8:f32, 1.000000e+00>>) {
+  // CHECK: "tosa.const"() <{value = dense<0> : tensor<1x2x3xi8>}> : () -> tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>
+  // CHECK: tosa.reshape %0 {new_shape = array<i64: 2, 1, 3>} : (tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<2x1x3x!quant.uniform<i8:f32, 1.000000e+00>>
   %perms = "tosa.const"() {value = dense<[1, 0, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
   %0 = "tosa.const"() {value = dense<0> : tensor<1x2x3xi8>} : ()-> tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>
-  %1 = tosa.transpose %0, %perms : (tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>, tensor<3xi32>) -> tensor<2x1x3xi8>
-  return %1 :  tensor<2x1x3xi8>
+  %1 = tosa.transpose %0, %perms : (tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>, tensor<3xi32>) -> tensor<2x1x3x!quant.uniform<i8:f32, 1.000000e+00>>
+  return %1 :  tensor<2x1x3x!quant.uniform<i8:f32, 1.000000e+00>>
 }
 
+// -----
+
 // CHECK-LABEL: @slice_fold
 func.func @slice_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   // CHECK: return %arg0
@@ -441,6 +552,8 @@ func.func @slice_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   return %0 : tensor<3x4xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @slice_nofold
 func.func @slice_nofold(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
   // CHECK: tosa.slice
@@ -448,6 +561,8 @@ func.func @slice_nofold(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
   return %0 : tensor<?x4xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @tile_fold
 func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   // CHECK: return %arg0
@@ -455,6 +570,8 @@ func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   return %0 : tensor<3x4xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @tile_nofold
 func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> {
   // CHECK: tosa.tile
@@ -462,6 +579,8 @@ func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> {
   return %0 : tensor<3x8xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @transpose_no_op
 func.func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32> {
   // CHECK: return %arg0
@@ -471,6 +590,8 @@ func.func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32> {
   return %1 : tensor<3x4x5x6xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @transpose_is_reshape
 func.func @transpose_is_reshape(%arg0: tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32> {
   // CHECK: tosa.reshape %arg0 {new_shape = array<i64: 1, 4, 1, 5>} : (tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32>
@@ -479,6 +600,8 @@ func.func @transpose_is_reshape(%arg0: tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf3
   return %0 : tensor<1x4x1x5xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @single_bit_reshape
 // https://github.com/llvm/llvm-project/issues/55440
 func.func @single_bit_reshape() -> tensor<1xi1> {
@@ -603,7 +726,7 @@ func.func @fold_abs_abs(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
 // -----
 
 // CHECK-LABEL: @fold_reduce_rank_zero
-func.func nested @fold_reduce_rank_zero() {
+func.func @fold_reduce_rank_zero() {
   // CHECK-NOT: tosa.reduce_min
   // CHECK-NOT: tosa.reverse
   %0 = tensor.empty() : tensor<i32>
@@ -622,6 +745,113 @@ func.func nested @fold_tile_rank_zero() -> tensor<i32> {
   return %1 : tensor<i32>
 }
 
+// -----
+
+// CHECK-LABEL: @reshape_quant_nofold
+// check that segfault is fixed
+func.func @reshape_quant_nofold() -> tensor<1x1x1x1xi32> {
+   %0 = "tosa.const"() {value = dense<127> : tensor<i8>} : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   %1 = tosa.reshape %0 {new_shape = array<i64: 1, 1, 1, 1>} : (tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   %2 = tosa.rescale %1 {double_round = true, input_zp = -128 : i32, multiplier = array<i32: 1073741824>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i32: 30>} : (tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<1x1x1x1xi32>
+   return %2 : tensor<1x1x1x1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @add_quant_nofold
+// check that segfault is fixed
+func.func @add_quant_nofold() -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>> {
+   %0 = "tosa.const"() {value = dense<127> : tensor<i8>} : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   %1 = tosa.add %0, %0 : (tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   return %1 : tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+}
+
+// -----
+
+// CHECK-LABEL: @sub_quant_nofold
+// check that segfault is fixed
+func.func @sub_quant_nofold() -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>> {
+   %0 = "tosa.const"() {value = dense<127> : tensor<i8>} : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   %1 = tosa.sub %0, %0 : (tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   return %1 : tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+}
+
+// -----
+
+// CHECK-LABEL: @greater_quant_fold
+func.func @greater_quant_fold() -> tensor<i1> {
+   %0 = "tosa.const"() {value = dense<0> : tensor<i8>} : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   // CHECK: "tosa.const"() <{value = dense<false>
+   %2 = "tosa.greater"(%0, %0) : (tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<i1>
+   return %2 : tensor<i1>
+}
+
+// -----
+
+// CHECK-LABEL: @greater_equal_quant_fold
+func.func @greater_equal_quant_fold() -> tensor<i1> {
+   %0 = "tosa.const"() {value = dense<0> : tensor<i8>} : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   // CHECK: "tosa.const"() <{value = dense<true>
+   %2 = "tosa.greater_equal"(%0, %0) : (tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<i1>
+   return %2 : tensor<i1>
+}
+
+// -----
+
+// CHECK-LABEL: @equal_quant_fold
+func.func @equal_quant_fold() -> tensor<i1> {
+   %0 = "tosa.const"() {value = dense<0> : tensor<i8>} : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   // CHECK: "tosa.const"() <{value = dense<true>
+   %2 = "tosa.equal"(%0, %0) : (tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<i1>
+   return %2 : tensor<i1>
+}
+
+// -----
+
+// CHECK-LABEL: @cast_quant_nofold
+func.func @cast_quant_nofold() -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:3>> {
+  // CHECK: tosa.cast
+   %0 = "tosa.const"() {value = dense<0> : tensor<i8>} : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   %1 = "tosa.cast"(%0) : (tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:3>>
+   return %1 : tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:3>>
+}
+
+// -----
+
+// CHECK-LABEL: @reverse_quant_fold
+func.func @reverse_quant_fold() -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>> {
+   // CHECK: %[[CST:.*]] = "tosa.const"() <{value = dense<0> : tensor<i8>}> : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   // CHECK: return %[[CST]]
+   %0 = "tosa.const"() {value = dense<0> : tensor<i8>} : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   %1 = "tosa.reverse"(%0) { axis = 0 : i32 } : (tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   return %1 : tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+}
+
+// -----
+
+// CHECK-LABEL: @select_quant_fold
+func.func @select_quant_fold() -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>> {
+   // CHECK: %[[CONST_0:.*]] = "tosa.const"() <{value = dense<0> : tensor<i8>}> : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   // CHECK: return %[[CONST_0]]
+   %0 = "tosa.const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+   %1 = "tosa.const"() {value = dense<0> : tensor<i8>} : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   %2 = "tosa.const"() {value = dense<127> : tensor<i8>} : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   %3 = "tosa.select"(%0, %1, %2) : (tensor<i1>, tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   return %3 : tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+}
+
+// -----
+
+// CHECK-LABEL: @mul_quant_nofold
+func.func @mul_quant_nofold() -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>> {
+   // CHECK: tosa.mul
+   %0 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   %1 = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   %2 = tosa.mul %0, %1 { shift = 0 : i8} : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   return %2 : tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+}
+
+
 // -----
 
 // CHECK-LABEL: @fold_reciprocal



More information about the Mlir-commits mailing list