[Mlir-commits] [mlir] 5ac9d66 - [DenseElementsAttr] Teach isValidRawBuffer that 1-elt values are splats.

Chris Lattner llvmlistbot at llvm.org
Sat May 14 03:49:52 PDT 2022


Author: Chris Lattner
Date: 2022-05-14T11:49:43+01:00
New Revision: 5ac9d662093ddd880f9059f84b17296a882e7985

URL: https://github.com/llvm/llvm-project/commit/5ac9d662093ddd880f9059f84b17296a882e7985
DIFF: https://github.com/llvm/llvm-project/commit/5ac9d662093ddd880f9059f84b17296a882e7985.diff

LOG: [DenseElementsAttr] Teach isValidRawBuffer that 1-elt values are splats.

We want getRaw() on tensors with i1 element type with a zero or 1 value
to be treated as a splat.  This fixes:
https://github.com/llvm/llvm-project/issues/55440

Added: 
    

Modified: 
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index a9ba366c61a1..8fa4e0360c29 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -809,12 +809,15 @@ bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
                                          bool &detectedSplat) {
   size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
+  int64_t numElements = type.getNumElements();
+
+  // The initializer is always a splat if the result type has a single element.
+  detectedSplat = numElements == 1;
 
   // Storage width of 1 is special as it is packed by the bit.
   if (storageWidth == 1) {
     // Check for a splat, or a buffer equal to the number of elements which
     // consists of either all 0's or all 1's.
-    detectedSplat = false;
     if (rawBuffer.size() == 1) {
       auto rawByte = static_cast<uint8_t>(rawBuffer[0]);
       if (rawByte == 0 || rawByte == 0xff) {
@@ -822,12 +825,20 @@ bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
         return true;
       }
     }
-    return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
+
+    // This is a valid non-splat buffer if it has the right size.
+    return rawBufferWidth == llvm::alignTo<8>(numElements);
   }
-  // All other types are 8-bit aligned.
-  if ((detectedSplat = rawBufferWidth == storageWidth))
+
+  // All other types are 8-bit aligned, so we can just check the buffer width
+  // to know if only a single initializer element was passed in.
+  if (rawBufferWidth == storageWidth) {
+    detectedSplat = true;
     return true;
-  return rawBufferWidth == (storageWidth * type.getNumElements());
+  }
+
+  // The raw buffer is valid if it has the right size.
+  return rawBufferWidth == storageWidth * numElements;
 }
 
 /// Check the information for a C++ data type, check if this type is valid for

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index f9d8164bf5cd..934ca583f330 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --split-input-file --canonicalize %s | FileCheck %s
+// RUN: mlir-opt --canonicalize %s | FileCheck %s
 
 // CHECK-LABEL: @argmax_nofold
 func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
@@ -7,8 +7,6 @@ func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @add_zero_different_shape
 func.func @add_zero_different_shape(%arg0: tensor<2x3xi32>) -> tensor<4x2x3xi32> {
   // CHECK: tosa.add
@@ -18,8 +16,6 @@ func.func @add_zero_different_shape(%arg0: tensor<2x3xi32>) -> tensor<4x2x3xi32>
 }
 
 
-// -----
-
 // CHECK-LABEL: @add_zero_int
 func.func @add_zero_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
   // CHECK: return %arg0
@@ -29,8 +25,6 @@ func.func @add_zero_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
   return %1 : tensor<2x3xi32>
 }
 
-// -----
-
 // CHECK-LABEL: @cast_fold
 func.func @cast_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
@@ -38,8 +32,6 @@ func.func @cast_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @cast_nofold
 func.func @cast_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
   // CHECK: "tosa.cast"
@@ -47,8 +39,6 @@ func.func @cast_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
   return %0 : tensor<?x1xi32>
 }
 
-// -----
-
 // CHECK-LABEL: @clamp_not_noop
 func.func @clamp_not_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
   // CHECK: "tosa.clamp"
@@ -56,8 +46,6 @@ func.func @clamp_not_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
   return %0 : tensor<4xi32>
 }
 
-// -----
-
 // CHECK-LABEL: @clamp_float_is_noop
 func.func @clamp_float_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
   // CHECK: return %arg0
@@ -66,8 +54,6 @@ func.func @clamp_float_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
   return %0 : tensor<4xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @clamp_int8_is_noop
 func.func @clamp_int8_is_noop(%arg0: tensor<4xi8>) -> tensor<4xi8> {
   // CHECK: return %arg0
@@ -76,8 +62,6 @@ func.func @clamp_int8_is_noop(%arg0: tensor<4xi8>) -> tensor<4xi8> {
   return %0 : tensor<4xi8>
 }
 
-// -----
-
 // CHECK-LABEL: @clamp_int16_is_noop
 func.func @clamp_int16_is_noop(%arg0: tensor<4xi16>) -> tensor<4xi16> {
   // CHECK: return %arg0
@@ -86,8 +70,6 @@ func.func @clamp_int16_is_noop(%arg0: tensor<4xi16>) -> tensor<4xi16> {
   return %0 : tensor<4xi16>
 }
 
-// -----
-
 // CHECK-LABEL: @clamp_uint8_is_noop
 func.func @clamp_uint8_is_noop(%arg0: tensor<4xui8>) -> tensor<4xui8> {
   // CHECK: return %arg0
@@ -96,8 +78,6 @@ func.func @clamp_uint8_is_noop(%arg0: tensor<4xui8>) -> tensor<4xui8> {
   return %0 : tensor<4xui8>
 }
 
-// -----
-
 // CHECK-LABEL: @clamp_twice_is_single_clamp
 func.func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
   // CHECK: "tosa.clamp"(%arg0) {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64}
@@ -106,8 +86,6 @@ func.func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
   return %1 : tensor<4xi8>
 }
 
-// -----
-
 // CHECK-LABEL: @concat_fold
 func.func @concat_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
@@ -115,8 +93,6 @@ func.func @concat_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @concat_fold_cast
 func.func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
   // CHECK: %[[VAR0:.*]] = tensor.cast %arg0
@@ -125,8 +101,6 @@ func.func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
   return %0 : tensor<?x?xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @conv2d_stride_2
 func.func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32> {
   // CHECK: "tosa.conv2d"
@@ -136,8 +110,6 @@ func.func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32
   return %0 : tensor<4x10x10x3xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @conv2d_weight_2x2
 func.func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf32> {
   // CHECK: "tosa.conv2d"
@@ -147,8 +119,6 @@ func.func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf
   return %0 : tensor<4x10x10x1xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @depthwise_conv2d_stride_2
 func.func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
   // CHECK: "tosa.depthwise_conv2d"
@@ -156,8 +126,6 @@ func.func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor
   return %0 : tensor<4x10x10x6xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @depthwise_conv2d_weight_2x2
 func.func @depthwise_conv2d_weight_2x2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<2x2x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
   // CHECK: "tosa.depthwise_conv2d"
@@ -165,8 +133,6 @@ func.func @depthwise_conv2d_weight_2x2(%arg0: tensor<4x10x10x2xf32>, %arg1: tens
   return %0 : tensor<4x10x10x6xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @max_pool2d_is_noop
 func.func @max_pool2d_is_noop(%arg0: tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf32> {
   // CHECK-NOT: "tosa.max_pool2d"
@@ -175,8 +141,6 @@ func.func @max_pool2d_is_noop(%arg0: tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf3
   return %0 : tensor<10x1x1x3xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @pad_noop
 func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
   // CHECK: return %arg0
@@ -185,8 +149,6 @@ func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
   return %1 : tensor<?x?xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @pad_determine_val_i32
 func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0> : tensor<i32>}
@@ -196,8 +158,6 @@ func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>
   return %1 : tensor<?x?xi32>
 }
 
-// -----
-
 // CHECK-LABEL: @pad_determine_val_f32
 func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xf32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<f32>}
@@ -207,8 +167,6 @@ func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>
   return %1 : tensor<?x?xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @pad_determine_val_quant
 func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<42> : tensor<i32>}
@@ -218,8 +176,6 @@ func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi3
   return %1 : tensor<?x?xi32>
 }
 
-// -----
-
 // CHECK-LABEL: @mul_one_different_shape
 func.func @mul_one_different_shape(%arg0: tensor<2x3xf32>) -> tensor<4x2x3xf32> {
   // CHECK: tosa.mul
@@ -228,8 +184,6 @@ func.func @mul_one_different_shape(%arg0: tensor<2x3xf32>) -> tensor<4x2x3xf32>
   return %1 : tensor<4x2x3xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @mul_one_float
 func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   // CHECK: return %arg0
@@ -239,8 +193,6 @@ func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   return %1 : tensor<2x3xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @mul_one_int
 func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
   // CHECK: return %arg0
@@ -250,8 +202,6 @@ func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
   return %1 : tensor<2x3xi32>
 }
 
-// -----
-
 // CHECK-LABEL: @select_same_value
 func.func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
   %0 = "tosa.select"(%arg0, %arg1, %arg1) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
@@ -260,8 +210,6 @@ func.func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> t
   return %0 : tensor<2x3xi32>
 }
 
-// -----
-
 // CHECK-LABEL: @select_true_value
 func.func @select_true_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
   %c1 = "tosa.const"() {value = dense<1> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
@@ -271,8 +219,6 @@ func.func @select_true_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) ->
   return %0 : tensor<2x3xi32>
 }
 
-// -----
-
 // CHECK-LABEL: @select_false_value
 func.func @select_false_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
   %c0 = "tosa.const"() {value = dense<0> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
@@ -282,8 +228,6 @@ func.func @select_false_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) ->
   return %0 : tensor<2x3xi32>
 }
 
-// -----
-
 // CHECK-LABEL: @select_not_pred
 func.func @select_not_pred(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
   %0 = "tosa.logical_not"(%arg0) : (tensor<2x3xi1>) -> tensor<2x3xi1>
@@ -292,8 +236,6 @@ func.func @select_not_pred(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2:
   return %1 : tensor<2x3xi32>
 }
 
-// -----
-
 // CHECK-LABEL: @reduce_all_fold
 func.func @reduce_all_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
@@ -301,8 +243,6 @@ func.func @reduce_all_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @reduce_all_nofold
 func.func @reduce_all_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: "tosa.reduce_all"
@@ -310,8 +250,6 @@ func.func @reduce_all_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @reduce_any_fold
 func.func @reduce_any_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
@@ -319,8 +257,6 @@ func.func @reduce_any_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @reduce_any_nofold
 func.func @reduce_any_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: "tosa.reduce_any"
@@ -328,8 +264,6 @@ func.func @reduce_any_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @reduce_max_fold
 func.func @reduce_max_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
@@ -337,8 +271,6 @@ func.func @reduce_max_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @reduce_max_nofold
 func.func @reduce_max_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: "tosa.reduce_max"
@@ -346,8 +278,6 @@ func.func @reduce_max_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @reduce_min_fold
 func.func @reduce_min_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
@@ -355,8 +285,6 @@ func.func @reduce_min_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @reduce_min_nofold
 func.func @reduce_min_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: "tosa.reduce_min"
@@ -364,8 +292,6 @@ func.func @reduce_min_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @reduce_prod_fold
 func.func @reduce_prod_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
@@ -373,8 +299,6 @@ func.func @reduce_prod_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @reduce_prod_nofold
 func.func @reduce_prod_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: "tosa.reduce_prod"
@@ -382,8 +306,6 @@ func.func @reduce_prod_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @reduce_sum_fold
 func.func @reduce_sum_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
@@ -391,8 +313,6 @@ func.func @reduce_sum_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @reduce_sum_nofold
 func.func @reduce_sum_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: "tosa.reduce_sum"
@@ -400,8 +320,6 @@ func.func @reduce_sum_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @reshape_canonicalize
 func.func @reshape_canonicalize(%arg0: tensor<?x10xf32>) -> tensor<?x10xf32> {
   // CHECK: return %arg0
@@ -409,8 +327,6 @@ func.func @reshape_canonicalize(%arg0: tensor<?x10xf32>) -> tensor<?x10xf32> {
   return %0 : tensor<?x10xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @reshape_canonicalize_double
 func.func @reshape_canonicalize_double(%arg0: tensor<?x10xf32>) -> tensor<?x5xf32> {
   // CHECK: %[[VAR0:.+]] = "tosa.reshape"(%arg0) {new_shape = [-1, 5]}
@@ -420,8 +336,6 @@ func.func @reshape_canonicalize_double(%arg0: tensor<?x10xf32>) -> tensor<?x5xf3
   return %1 : tensor<?x5xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @reshape_canonicalize_const
 func.func @reshape_canonicalize_const() -> tensor<1x10xi32> {
   // CHECK: %[[VAR0:.+]] = "tosa.const"() {value = dense<0> : tensor<1x10xi32>}
@@ -431,8 +345,6 @@ func.func @reshape_canonicalize_const() -> tensor<1x10xi32> {
   return %1 : tensor<1x10xi32>
 }
 
-// -----
-
 // CHECK-LABEL: @reshape_canonicalize_const_spat
 func.func @reshape_canonicalize_const_spat() -> (tensor<10xi32>, tensor<1x10xi32>) {
   // CHECK-DAG: %[[VAR0:.+]] = "tosa.const"() {value = dense<0> : tensor<10xi32>}
@@ -443,8 +355,6 @@ func.func @reshape_canonicalize_const_spat() -> (tensor<10xi32>, tensor<1x10xi32
   return %0 , %1 : tensor<10xi32>, tensor<1x10xi32>
 }
 
-// -----
-
 // CHECK-LABEL: @reshape_canonicalize_const_sparse
 func.func @reshape_canonicalize_const_sparse() -> (tensor<3xi32>, tensor<1x3xi32>) {
   //CHECK: "tosa.reshape"
@@ -453,8 +363,6 @@ func.func @reshape_canonicalize_const_sparse() -> (tensor<3xi32>, tensor<1x3xi32
   return %0 , %1 : tensor<3xi32>, tensor<1x3xi32>
 }
 
-// -----
-
 // CHECK-LABEL: @slice_fold
 func.func @slice_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   // CHECK: return %arg0
@@ -462,8 +370,6 @@ func.func @slice_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   return %0 : tensor<3x4xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @slice_nofold
 func.func @slice_nofold(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
   // CHECK: "tosa.slice"
@@ -471,8 +377,6 @@ func.func @slice_nofold(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
   return %0 : tensor<?x4xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @tile_fold
 func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   // CHECK: return %arg0
@@ -480,8 +384,6 @@ func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   return %0 : tensor<3x4xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @tile_nofold
 func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> {
   // CHECK: "tosa.tile"
@@ -489,8 +391,6 @@ func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> {
   return %0 : tensor<3x8xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @transpose_fold
 func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   // CHECK: return %arg0
@@ -499,8 +399,6 @@ func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   return %1 : tensor<3x4xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @transpose_nofold
 func.func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
   // CHECK: "tosa.transpose"
@@ -509,8 +407,6 @@ func.func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
   return %1 : tensor<3x3xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @transpose_nofold_shape
 func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
   // CHECK: "tosa.transpose"
@@ -519,8 +415,6 @@ func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
   return %1 : tensor<?x?xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @transpose_fold_splat
 func.func @transpose_fold_splat() -> tensor<3x2xf32> {
   %input = "tosa.const"() {value = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
@@ -532,8 +426,6 @@ func.func @transpose_fold_splat() -> tensor<3x2xf32> {
   return %1 : tensor<3x2xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @transpose_fold_2d_float
 func.func @transpose_fold_2d_float() -> tensor<3x2xf32> {
   %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
@@ -545,8 +437,6 @@ func.func @transpose_fold_2d_float() -> tensor<3x2xf32> {
   return %1 : tensor<3x2xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @transpose_fold_4d_int
 func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
   %input = "tosa.const"() {value = dense<[[
@@ -565,8 +455,6 @@ func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
   return %1 : tensor<3x1x4x2xi32>
 }
 
-// -----
-
 // CHECK-LABEL: @transpose_nofold_non_cst_input
 func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2xf32> {
   %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
@@ -575,8 +463,6 @@ func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2
   return %1 : tensor<3x2xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @transpose_nofold_non_cst_perms
 func.func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf32> {
   %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
@@ -585,8 +471,6 @@ func.func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf
   return %1 : tensor<3x2xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @transpose_nofold_multi_users
 func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) {
   %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
@@ -596,8 +480,6 @@ func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>)
   return %1, %input : tensor<3x2xf32>, tensor<2x3xf32>
 }
 
-// -----
-
 // CHECK-LABEL: @transpose_nofold_quantized_types
 func.func @transpose_nofold_quantized_types() -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> {
   %perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
@@ -607,8 +489,6 @@ func.func @transpose_nofold_quantized_types() -> tensor<1x1x16x1x!quant.uniform<
   return %0: tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
 }
 
-// -----
-
 // CHECK-LABEL: @transpose_no_op
 func.func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32> {
   // CHECK: return %arg0
@@ -617,3 +497,12 @@ func.func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32> {
   %1 = "tosa.transpose"(%arg0, %perms) : (tensor<3x4x5x6xf32>, tensor<4xi32>) -> tensor<3x4x5x6xf32>
   return %1 : tensor<3x4x5x6xf32>
 }
+
+// CHECK-LABEL: @single_bit_reshape
+// https://github.com/llvm/llvm-project/issues/55440
+func.func @single_bit_reshape() -> tensor<1xi1> {
+  // CHECK: "tosa.const"() {value = dense<true> : tensor<1xi1>}
+  %0 = arith.constant dense<true> : tensor<1x1xi1>
+  %1 = "tosa.reshape"(%0) {new_shape = [1]} : (tensor<1x1xi1>) -> tensor<1xi1>
+  return %1 : tensor<1xi1>
+}


        


More information about the Mlir-commits mailing list