[Mlir-commits] [mlir] 9739ef6 - [MLIR][Tosa] Turn reshape(const()) from canonicalization into fold; fix dynamic shape case

Matthias Gehre llvmlistbot at llvm.org
Thu Jul 6 11:07:30 PDT 2023


Author: Liam Fitzpatrick
Date: 2023-07-06T20:07:24+02:00
New Revision: 9739ef67a8f79ed88d4e28710c6d5c67d5566425

URL: https://github.com/llvm/llvm-project/commit/9739ef67a8f79ed88d4e28710c6d5c67d5566425
DIFF: https://github.com/llvm/llvm-project/commit/9739ef67a8f79ed88d4e28710c6d5c67d5566425.diff

LOG: [MLIR][Tosa] Turn reshape(const()) from canonicalization into fold; fix dynamic shape case

1) Turns the canonicalization into a fold, so that it can clean up IR within other passes.

2) When the output of the reshape is a dynamically shaped tensor,
the fold cannot be applied to the constant, because constants are required to have
a static shape.

Differential Revision: https://reviews.llvm.org/D154615

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 5af10ec80a5c90..8cefa64bc4c1fe 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -82,40 +82,9 @@ struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
   }
 };
 
-struct ReshapeConstOptimization : public OpRewritePattern<tosa::ReshapeOp> {
-  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.getInput1();
-    ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
-    ShapedType resultTy = llvm::cast<ShapedType>(op.getType());
-
-    if (inputTy.getElementType() != resultTy.getElementType())
-      return rewriter.notifyMatchFailure(op, "element type does not match.");
-
-    // Check if input is constant
-    DenseElementsAttr inputAttr;
-    if (!matchPattern(input, m_Constant(&inputAttr)))
-      return rewriter.notifyMatchFailure(op, "Non-constant input.");
-
-    // Check if has >1 consumer and is not splat
-    if (!input.hasOneUse() && !inputAttr.isSplat())
-      return rewriter.notifyMatchFailure(op,
-                                         "Used more than once or not-splat");
-
-    // Build new const op with correct output shape
-    DenseElementsAttr outputAttr = inputAttr.reshape(
-        llvm::cast<ShapedType>(inputAttr.getType()).clone(op.getNewShape()));
-    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, resultTy, outputAttr);
-    return success();
-  }
-};
-
 void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
   results.add<ReshapeReshapeOptimization>(context);
-  results.add<ReshapeConstOptimization>(context);
 }
 
 LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
@@ -851,12 +820,32 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
   if (inputTy == outputTy)
     return getInput1();
 
+  // Constants must have static shape.
+  if (!outputTy.hasStaticShape())
+    return {};
+
   auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
-  if (operand && outputTy.hasStaticShape() && operand.isSplat()) {
+  if (!operand)
+    return {};
+
+  // The attributes built below keep the operand's element type, so bail out
+  // if the result type would change it.
+  if (operand.getElementType() != outputTy.getElementType())
+    return {};
+
+  // Okay to duplicate splat constants.
+  if (operand.isSplat()) {
     return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
   }
 
-  return {};
+  // Don't duplicate other constants.
+  if (!getInput1().hasOneUse())
+    return {};
+
+  // Reshape to the result type, which was checked to be static above, rather
+  // than cloning `new_shape`, which is not validated here (it may, e.g., hold
+  // a placeholder such as -1 for an inferred dimension).
+  return operand.reshape(outputTy);
 }
 
 OpFoldResult PadOp::fold(FoldAdaptor adaptor) {

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 3f6d0e0b014112..1f9ce1f4c7b929 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -375,16 +375,36 @@ func.func @reshape_canonicalize_double(%arg0: tensor<?x10xf32>) -> tensor<?x5xf3
 }
 
 // CHECK-LABEL: @reshape_canonicalize_const
-func.func @reshape_canonicalize_const() -> tensor<1x10xi32> {
-  // CHECK: %[[VAR0:.+]] = "tosa.const"() <{value = dense<0> : tensor<1x10xi32>}
+func.func @reshape_canonicalize_const() -> tensor<1x5xi32> {
+  // CHECK: %[[VAR0:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 1, 2, 3, 4]]> : tensor<1x5xi32>}
   // CHECK: return %[[VAR0]]
-  %0 = "tosa.const"() {value = dense<0> : tensor<10xi32>} : () -> tensor<10xi32>
-  %1 = "tosa.reshape"(%0) {new_shape = array<i64: 1, 10>} : (tensor<10xi32>) -> tensor<1x10xi32>
-  return %1 : tensor<1x10xi32>
+  %0 = "tosa.const"() {value = dense<[0, 1, 2, 3, 4]> : tensor<5xi32>} : () -> tensor<5xi32>
+  %1 = "tosa.reshape"(%0) {new_shape = array<i64: 1, 5>} : (tensor<5xi32>) -> tensor<1x5xi32>
+  return %1 : tensor<1x5xi32>
+}
+
+// CHECK-LABEL: @reshape_canonicalize_const_dynamic
+func.func @reshape_canonicalize_const_dynamic() -> tensor<1x?xi32> {
+  // CHECK: tosa.reshape
+  %0 = "tosa.const"() {value = dense<[0, 1, 2, 3, 4]> : tensor<5xi32>} : () -> tensor<5xi32>
+  %1 = "tosa.reshape"(%0) {new_shape = array<i64: 1, 5>} : (tensor<5xi32>) -> tensor<1x?xi32>
+  return %1 : tensor<1x?xi32>
 }
+
+// Non-splat constants with more than one use must not be folded, so the
+// constant data is not duplicated.
+// CHECK-LABEL: @reshape_canonicalize_const_multi_use
+func.func @reshape_canonicalize_const_multi_use() -> (tensor<5xi32>, tensor<1x5xi32>) {
+  // CHECK: %[[VAR0:.+]] = "tosa.const"() <{value = dense<[0, 1, 2, 3, 4]> : tensor<5xi32>}
+  // CHECK: %[[VAR1:.+]] = "tosa.reshape"(%[[VAR0]])
+  // CHECK: return %[[VAR0]], %[[VAR1]]
+  %0 = "tosa.const"() {value = dense<[0, 1, 2, 3, 4]> : tensor<5xi32>} : () -> tensor<5xi32>
+  %1 = "tosa.reshape"(%0) {new_shape = array<i64: 1, 5>} : (tensor<5xi32>) -> tensor<1x5xi32>
+  return %0, %1 : tensor<5xi32>, tensor<1x5xi32>
+}
 
-// CHECK-LABEL: @reshape_canonicalize_const_spat
-func.func @reshape_canonicalize_const_spat() -> (tensor<10xi32>, tensor<1x10xi32>) {
+// CHECK-LABEL: @reshape_canonicalize_const_splat
+func.func @reshape_canonicalize_const_splat() -> (tensor<10xi32>, tensor<1x10xi32>) {
   // CHECK-DAG: %[[VAR0:.+]] = "tosa.const"() <{value = dense<0> : tensor<10xi32>}
   // CHECK-DAG: %[[VAR1:.+]] = "tosa.const"() <{value = dense<0> : tensor<1x10xi32>}
   // CHECK: return %[[VAR0]], %[[VAR1]]


        


More information about the Mlir-commits mailing list