[Mlir-commits] [mlir] [Tosa] Disable tosa folder for non-int/float/index types (PR #71757)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 8 18:30:57 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Tai Ly (Tai78641)

<details>
<summary>Changes</summary>

In order to fold, we need to create a DenseElementsAttr, which does not support quantized element types. This patch adds tests for folding quantized element types and disables the tosa folders where appropriate.

Refactored the canonicalize.mlir test to use --split-input-file.

Also fixed the verifier of the MulOperandsAndResultElementType trait for quantized element types.

---

Patch is 30.20 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/71757.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h (+3-6) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+16-2) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+237-9) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index a9bc3351f4cff05..5f7a6dfb62a3f16 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -55,11 +55,6 @@ class MulOperandsAndResultElementType
   static LogicalResult verifyTrait(Operation *op) {
     auto resElemType = getElementTypeOrSelf(op->getResult(0));
 
-    // In cases of floating point type, op requires the same element
-    // type for all operands and result.
-    if (llvm::isa<FloatType>(resElemType))
-      return impl::verifySameOperandsAndResultElementType(op);
-
     if (auto resIntType = resElemType.dyn_cast<IntegerType>()) {
       IntegerType lhsIntType =
           getElementTypeOrSelf(op->getOperand(0)).cast<IntegerType>();
@@ -78,7 +73,9 @@ class MulOperandsAndResultElementType
       return success();
     }
 
-    return failure();
+    // In cases of all other types, op requires the same element
+    // type for all operands and result.
+    return impl::verifySameOperandsAndResultElementType(op);
   }
 };
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 7444f70a46e9355..46b726c779012e8 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -489,6 +489,11 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
   if (!lhsTy || !rhsTy || !resultTy)
     return {};
 
+  // Cannot create an ElementsAttr from non-int/float/index types
+  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
+      !rhsTy.getElementType().isIntOrIndexOrFloat())
+    return {};
+
   auto resultETy = resultTy.getElementType();
   auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
   auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
@@ -514,6 +519,7 @@ OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
   if (lhsTy != rhsTy)
     return {};
 
+  // IntDivOp inputs must be integer type, no need to check for qunatized type
   auto resultETy = resultTy.getElementType();
   auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
   auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
@@ -611,6 +617,11 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
   if (!lhsTy || !rhsTy || !resultTy)
     return {};
 
+  // Cannot create an ElementsAttr from non-int/float/index types
+  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
+      !rhsTy.getElementType().isIntOrIndexOrFloat())
+    return {};
+
   auto resultETy = resultTy.getElementType();
   auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
   auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
@@ -801,6 +812,10 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
     return getResult();
   }
 
+  // Cannot create an ElementsAttr from non-int/float/index types
+  if (!inputTy.getElementType().isIntOrIndexOrFloat())
+    return {};
+
   // reshape(const(x)) -> const(reshape-attr(x))
   if (auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
     // Constants must have static shape.
@@ -936,13 +951,12 @@ OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
 }
 
 OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
-  auto inputTy = llvm::cast<ShapedType>(getInput1().getType());
   auto resultTy = llvm::cast<ShapedType>(getType());
 
   // Transposing splat values just means reshaping.
   if (auto input = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
     if (input.isSplat() && resultTy.hasStaticShape() &&
-        inputTy.getElementType() == resultTy.getElementType())
+        input.getType().getElementType() == resultTy.getElementType())
       return input.reshape(resultTy);
   }
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index fd51d287bca0580..6dc3d5fec09253f 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -canonicalize="test-convergence" %s | FileCheck %s
+// RUN: mlir-opt --split-input-file -canonicalize="test-convergence" %s | FileCheck %s
 
 // CHECK-LABEL: @argmax_nofold
 func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
@@ -7,6 +7,8 @@ func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
   return %0 : tensor<?x1xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @add_bcast_zero_int
 func.func @add_bcast_zero_int(%arg0: tensor<4x2x3xi32>) -> tensor<4x2x3xi32> {
   // CHECK-NOT: tosa.add
@@ -16,6 +18,8 @@ func.func @add_bcast_zero_int(%arg0: tensor<4x2x3xi32>) -> tensor<4x2x3xi32> {
   return %1 : tensor<4x2x3xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @add_zero_int
 func.func @add_zero_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
   // CHECK: return %arg0
@@ -25,6 +29,8 @@ func.func @add_zero_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
   return %1 : tensor<2x3xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @cast_fold
 func.func @cast_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
@@ -32,6 +38,8 @@ func.func @cast_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @cast_nofold
 func.func @cast_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
   // CHECK: tosa.cast
@@ -39,6 +47,8 @@ func.func @cast_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
   return %0 : tensor<?x1xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @clamp_i32_not_noop
 func.func @clamp_i32_not_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
   // CHECK: tosa.clamp
@@ -46,6 +56,8 @@ func.func @clamp_i32_not_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
   return %0 : tensor<4xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @clamp_f16_not_noop
 func.func @clamp_f16_not_noop(%arg0: tensor<4xf16>) -> tensor<4xf16> {
   // CHECK: tosa.clamp
@@ -53,6 +65,8 @@ func.func @clamp_f16_not_noop(%arg0: tensor<4xf16>) -> tensor<4xf16> {
   return %0 : tensor<4xf16>
 }
 
+// -----
+
 // CHECK-LABEL: @clamp_f32_not_noop
 func.func @clamp_f32_not_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
   // CHECK: tosa.clamp
@@ -60,6 +74,8 @@ func.func @clamp_f32_not_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
   return %0 : tensor<4xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @clamp_f16_is_noop
 func.func @clamp_f16_is_noop(%arg0: tensor<4xf16>) -> tensor<4xf16> {
   // CHECK: return %arg0
@@ -69,6 +85,8 @@ func.func @clamp_f16_is_noop(%arg0: tensor<4xf16>) -> tensor<4xf16> {
   return %0 : tensor<4xf16>
 }
 
+// -----
+
 // CHECK-LABEL: @clamp_f32_is_noop
 func.func @clamp_f32_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
   // CHECK: return %arg0
@@ -78,6 +96,8 @@ func.func @clamp_f32_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
   return %0 : tensor<4xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @clamp_int8_is_noop
 func.func @clamp_int8_is_noop(%arg0: tensor<4xi8>) -> tensor<4xi8> {
   // CHECK: return %arg0
@@ -86,6 +106,8 @@ func.func @clamp_int8_is_noop(%arg0: tensor<4xi8>) -> tensor<4xi8> {
   return %0 : tensor<4xi8>
 }
 
+// -----
+
 // CHECK-LABEL: @clamp_int16_is_noop
 func.func @clamp_int16_is_noop(%arg0: tensor<4xi16>) -> tensor<4xi16> {
   // CHECK: return %arg0
@@ -94,6 +116,8 @@ func.func @clamp_int16_is_noop(%arg0: tensor<4xi16>) -> tensor<4xi16> {
   return %0 : tensor<4xi16>
 }
 
+// -----
+
 // CHECK-LABEL: @clamp_uint8_is_noop
 func.func @clamp_uint8_is_noop(%arg0: tensor<4xui8>) -> tensor<4xui8> {
   // CHECK: return %arg0
@@ -102,6 +126,8 @@ func.func @clamp_uint8_is_noop(%arg0: tensor<4xui8>) -> tensor<4xui8> {
   return %0 : tensor<4xui8>
 }
 
+// -----
+
 // CHECK-LABEL: @clamp_twice_is_single_clamp
 func.func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
   // CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64}
@@ -110,6 +136,8 @@ func.func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
   return %1 : tensor<4xi8>
 }
 
+// -----
+
 // CHECK-LABEL: @concat_fold
 func.func @concat_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
@@ -117,6 +145,8 @@ func.func @concat_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @concat_fold_cast
 func.func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
   // CHECK: %[[VAR0:.*]] = tensor.cast %arg0
@@ -125,6 +155,8 @@ func.func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
   return %0 : tensor<?x?xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @conv2d_stride_2
 func.func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32> {
   // CHECK: tosa.conv2d
@@ -134,6 +166,8 @@ func.func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32
   return %0 : tensor<4x10x10x3xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @conv2d_weight_2x2
 func.func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf32> {
   // CHECK: tosa.conv2d
@@ -143,6 +177,8 @@ func.func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf
   return %0 : tensor<4x10x10x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @depthwise_conv2d_stride_2
 func.func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
   // CHECK: tosa.depthwise_conv2d
@@ -150,6 +186,8 @@ func.func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor
   return %0 : tensor<4x10x10x6xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @depthwise_conv2d_weight_2x2
 func.func @depthwise_conv2d_weight_2x2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<2x2x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
   // CHECK: tosa.depthwise_conv2d
@@ -157,6 +195,8 @@ func.func @depthwise_conv2d_weight_2x2(%arg0: tensor<4x10x10x2xf32>, %arg1: tens
   return %0 : tensor<4x10x10x6xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @max_pool2d_is_noop
 func.func @max_pool2d_is_noop(%arg0: tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf32> {
   // CHECK-NOT: tosa.max_pool2d
@@ -165,6 +205,8 @@ func.func @max_pool2d_is_noop(%arg0: tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf3
   return %0 : tensor<10x1x1x3xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @pad_noop
 func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
   // CHECK: return %arg0
@@ -173,6 +215,8 @@ func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
   return %1 : tensor<?x?xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @pad_determine_val_i32
 func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}
@@ -182,6 +226,8 @@ func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>
   return %1 : tensor<?x?xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @pad_determine_val_f32
 func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xf32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
@@ -191,6 +237,8 @@ func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>
   return %1 : tensor<?x?xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @pad_determine_val_quant
 func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<42> : tensor<i32>}
@@ -200,6 +248,8 @@ func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi3
   return %1 : tensor<?x?xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @mul_one_float
 func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   // CHECK: return %arg0
@@ -209,6 +259,8 @@ func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   return %1 : tensor<2x3xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @mul_bcast_one_float
 func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   // CHECK: return %arg0
@@ -218,6 +270,8 @@ func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   return %1 : tensor<2x3xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @mul_one_int
 func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
   // CHECK: return %arg0
@@ -227,6 +281,8 @@ func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
   return %1 : tensor<2x3xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @mul_zero_broadcast
 func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>) {
   // CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3xf32>}
@@ -240,6 +296,8 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso
   return %1, %2 : tensor<2x3xf32>, tensor<2x3xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @select_same_value
 func.func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
   %0 = tosa.select %arg0, %arg1, %arg1 : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
@@ -248,6 +306,8 @@ func.func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> t
   return %0 : tensor<2x3xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @select_true_value
 func.func @select_true_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
   %c1 = "tosa.const"() {value = dense<1> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
@@ -257,6 +317,8 @@ func.func @select_true_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) ->
   return %0 : tensor<2x3xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @select_false_value
 func.func @select_false_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
   %c0 = "tosa.const"() {value = dense<0> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
@@ -266,6 +328,8 @@ func.func @select_false_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) ->
   return %0 : tensor<2x3xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @select_not_pred
 func.func @select_not_pred(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
   %0 = tosa.logical_not %arg0 : (tensor<2x3xi1>) -> tensor<2x3xi1>
@@ -274,6 +338,8 @@ func.func @select_not_pred(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2:
   return %1 : tensor<2x3xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @reduce_all_fold
 func.func @reduce_all_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
@@ -281,6 +347,8 @@ func.func @reduce_all_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reduce_all_nofold
 func.func @reduce_all_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: tosa.reduce_all
@@ -288,6 +356,8 @@ func.func @reduce_all_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reduce_any_fold
 func.func @reduce_any_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
@@ -295,6 +365,8 @@ func.func @reduce_any_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reduce_any_nofold
 func.func @reduce_any_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: tosa.reduce_any
@@ -302,6 +374,8 @@ func.func @reduce_any_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reduce_max_fold
 func.func @reduce_max_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
@@ -309,6 +383,8 @@ func.func @reduce_max_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reduce_max_nofold
 func.func @reduce_max_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: tosa.reduce_max
@@ -316,6 +392,8 @@ func.func @reduce_max_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reduce_min_fold
 func.func @reduce_min_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
@@ -323,6 +401,8 @@ func.func @reduce_min_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reduce_min_nofold
 func.func @reduce_min_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: tosa.reduce_min
@@ -330,6 +410,8 @@ func.func @reduce_min_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reduce_prod_fold
 func.func @reduce_prod_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
@@ -337,6 +419,8 @@ func.func @reduce_prod_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reduce_prod_nofold
 func.func @reduce_prod_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: tosa.reduce_prod
@@ -344,6 +428,8 @@ func.func @reduce_prod_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reduce_sum_fold
 func.func @reduce_sum_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
@@ -351,6 +437,8 @@ func.func @reduce_sum_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reduce_sum_nofold
 func.func @reduce_sum_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: tosa.reduce_sum
@@ -358,6 +446,8 @@ func.func @reduce_sum_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reshape_canonicalize
 func.func @reshape_canonicalize(%arg0: tensor<?x10xf32>) -> tensor<?x10xf32> {
   // CHECK: return %arg0
@@ -365,6 +455,8 @@ func.func @reshape_canonicalize(%arg0: tensor<?x10xf32>) -> tensor<?x10xf32> {
   return %0 : tensor<?x10xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reshape_canonicalize_double
 func.func @reshape_canonicalize_double(%arg0: tensor<?x10xf32>) -> tensor<?x5xf32> {
   // CHECK: %[[VAL_1:.*]] = tosa.reshape %arg0 {new_shape = array<i64: -1, 5>}
@@ -374,6 +466,8 @@ func.func @reshape_canonicalize_double(%arg0: tensor<?x10xf32>) -> tensor<?x5xf3
   return %1 : tensor<?x5xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @reshape_canonicalize_const
 func.func @reshape_canonicalize_const() -> tensor<1x5xi32> {
   // CHECK: %[[VAR0:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 1, 2, 3, 4]]> : tensor<1x5xi32>}
@@ -383,6 +477,8 @@ func.func @reshape_canonicalize_const() -> tensor<1x5xi32> {
   return %1 : tensor<1x5xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @reshape_canonicalize_const_dynamic
 func.func @reshape_canonicalize_const_dynamic() -> tensor<1x?xi32> {
   // CHECK: tosa.reshape
@@ -391,6 +487,8 @@ func.func @reshape_canonicalize_const_dynamic() -> tensor<1x?xi32> {
   return %1 : tensor<1x?xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @reshape_canonicalize_const_splat
 func.func @reshape_canonicalize_const_splat() -> (tensor<10xi32>, tensor<1x10xi32>) {
   // CHECK-DAG: %[[VAR0:.+]] = "tosa.const"() <{value = dense<0> : tensor<10xi32>}
@@ -401,6 +499,8 @@ func.func @reshape_canonicalize_const_splat() -> (tensor<10xi32>, tensor<1x10xi3
   return %0 , %1 : tensor<10xi32>, tensor<1x10xi32>
 }
 
+// -----
+
 // CHECK-LABEL: @reshape_canonicalize_const_sparse
 func.func @reshape_canonicalize_const_sparse() -> (tensor<3xi32>, tensor<1x3xi32>) {
   // CHECK: tosa.reshape
@@ -409,23 +509,32 @@ func.func @reshape_canonicalize_const_sparse() -> (tensor<3xi32>, tensor<1x3xi32
   return %0 , %1 : tensor<3xi32>, tensor<1x3xi32>
 }
 
-// CHECK-LABEL: @reshape_canonicalize_quant
-func.func @reshape_canonicalize_quant() -> (tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>>) {
-  // CHECK{LITERAL}: "tosa.const"() <{value = dense<[[1, 2, 3]]> : tensor<1x3xi8>}> : () -> tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>>
+// -----
+
+// CHECK-LABEL: @reshape_canonicalize_quant_nofold
+func.func @reshape_canonicalize_quant_nofold() -> (tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>>) {
+  // disabled folding for quantized element types
+  // CHECK{LITERAL}: "tosa.const"() <{value = dense<[1, 2, 3]> : tensor<3xi8>}> : () -> tensor<3x!quant.uniform<i8:f32, 1.000000e+00>>
+ ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/71757


More information about the Mlir-commits mailing list