[Mlir-commits] [mlir] 09bd5ae - [mlir][tosa] Fix `tosa.reshape` folder for quantized constants

Robert Suderman llvmlistbot at llvm.org
Mon Apr 24 18:14:19 PDT 2023


Author: Robert Suderman
Date: 2023-04-25T01:13:20Z
New Revision: 09bd5ae49ea84c734cec35ec8555b16edb13c7b4

URL: https://github.com/llvm/llvm-project/commit/09bd5ae49ea84c734cec35ec8555b16edb13c7b4
DIFF: https://github.com/llvm/llvm-project/commit/09bd5ae49ea84c734cec35ec8555b16edb13c7b4.diff

LOG: [mlir][tosa] Fix `tosa.reshape` folder for quantized constants

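It is possible for `tosa.const` to have a quantized result type (e.g.
`tensor<3x!quant.uniform<i8:f32, ...>>`) while its `value` attribute
uses the underlying storage type (e.g. `tensor<3xi8>`). When folding a
`tosa.reshape` of such a constant, the new `tosa.const` must keep the
reshape's expected result type rather than the attribute's type;
otherwise users of the folded value see a type mismatch.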

Reviewed By: cota

Differential Revision: https://reviews.llvm.org/D149109

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 19a80c783c475..16b8e92444f68 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -105,11 +105,9 @@ struct ReshapeConstOptimization : public OpRewritePattern<tosa::ReshapeOp> {
                                          "Used more than once or not-splat");
 
     // Build new const op with correct output shape
-    ShapedType inputShape = input.getType().cast<ShapedType>();
-    DenseElementsAttr outputAttr =
-        inputAttr.reshape(inputShape.clone(op.getNewShape()));
-    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputAttr.getType(),
-                                               outputAttr);
+    DenseElementsAttr outputAttr = inputAttr.reshape(
+        inputAttr.getType().cast<ShapedType>().clone(op.getNewShape()));
+    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, resultTy, outputAttr);
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index bdd4021cb39a1..eacff38ebeabf 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -358,12 +358,29 @@ func.func @reshape_canonicalize_const_spat() -> (tensor<10xi32>, tensor<1x10xi32
 
 // CHECK-LABEL: @reshape_canonicalize_const_sparse
 func.func @reshape_canonicalize_const_sparse() -> (tensor<3xi32>, tensor<1x3xi32>) {
-  //CHECK: "tosa.reshape"
+  // CHECK: "tosa.reshape"
   %0 = "tosa.const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : ()-> tensor<3xi32>
   %1 = "tosa.reshape"(%0) {new_shape = array<i64: 1, 3>} : (tensor<3xi32>) -> tensor<1x3xi32>
   return %0 , %1 : tensor<3xi32>, tensor<1x3xi32>
 }
 
+// CHECK-LABEL: @reshape_canonicalize_quant
+func.func @reshape_canonicalize_quant() -> (tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>>) {
+  // CHECK{literal}: "tosa.const"() {value = dense<[[1, 2, 3]]> : tensor<1x3xi8>} : () -> tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>> 
+  %0 = "tosa.const"() {value = dense<[1, 2, 3]> : tensor<3xi8>} : ()-> tensor<3x!quant.uniform<i8:f32, 1.000000e+00>>
+  %1 = "tosa.reshape"(%0) {new_shape = array<i64: 1, 3>} : (tensor<3x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>>
+  return %1 :  tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>>
+}
+
+// CHECK-LABEL: @transpose_canonicalize_strip_quant
+func.func @transpose_canonicalize_strip_quant() -> (tensor<2x1x3xi8>) {
+  // CHECK: "tosa.const"() {value = dense<0> : tensor<2x1x3xi8>} : () -> tensor<2x1x3xi8>
+  %perms = "tosa.const"() {value = dense<[1, 0, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
+  %0 = "tosa.const"() {value = dense<0> : tensor<1x2x3xi8>} : ()-> tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>
+  %1 = "tosa.transpose"(%0, %perms) : (tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>, tensor<3xi32>) -> tensor<2x1x3xi8>
+  return %1 :  tensor<2x1x3xi8>
+}
+
 // CHECK-LABEL: @slice_fold
 func.func @slice_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   // CHECK: return %arg0


        


More information about the Mlir-commits mailing list