[Mlir-commits] [mlir] 9c62bb5 - Implementation of `ReshapeConstOptimization` canonicalizer.

Rob Suderman llvmlistbot at llvm.org
Tue Oct 19 16:08:48 PDT 2021


Author: Kojo Acquah
Date: 2021-10-19T16:07:34-07:00
New Revision: 9c62bb55f473a9d0db16b894708ed09f2346ae9d

URL: https://github.com/llvm/llvm-project/commit/9c62bb55f473a9d0db16b894708ed09f2346ae9d
DIFF: https://github.com/llvm/llvm-project/commit/9c62bb55f473a9d0db16b894708ed09f2346ae9d.diff

LOG: Implementation of `ReshapeConstOptimization` canonicalizer.

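This canonicalizer replaces a `tosa.reshape` of a constant tensor with a new `tosa.const` that already has the reshaped type, eliminating the reshape operation. The rewrite applies only when the constant has a single use or is a splat.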

Differential Revision: https://reviews.llvm.org/D112038

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 2ad14f5bf1809..744eaa0b3bb8a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -154,9 +154,43 @@ struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
   }
 };
 
+struct ReshapeConstOptimization : public OpRewritePattern<tosa::ReshapeOp> {
+  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.input1();
+    ArrayAttr newShape = op.new_shape();
+
+    // Bail out unless the input matches m_Constant; its value goes to inputAttr.
+    DenseElementsAttr inputAttr;
+    if (!matchPattern(input, m_Constant(&inputAttr)))
+      return failure();
+
+    // Bail out when the input is not single-use, unless the constant is a splat.
+    if (!input.hasOneUse() && !inputAttr.isSplat())
+      return failure();
+
+    // Convert the new_shape ArrayAttr into a vector of int64_t dimension sizes.
+    SmallVector<int64_t> newShapeValues = llvm::to_vector<6>(
+        llvm::map_range(newShape.getValue(), [](const Attribute &val) {
+          return val.cast<IntegerAttr>().getValue().getSExtValue();
+        }));
+
+    // Replace the reshape with a tosa.const of the same data in the new shape.
+    ShapedType inputShape = input.getType().cast<ShapedType>();
+    DenseElementsAttr outputAttr =
+        inputAttr.reshape(inputShape.clone(newShapeValues));
+    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputAttr.getType(),
+                                               outputAttr);
+    return success();
+  }
+};
+
 void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
   results.insert<ReshapeReshapeOptimization>(context);
+  results.insert<ReshapeConstOptimization>(context);
 }
 
 struct ConstantTransposeOptimization

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 983ce58c0c652..e6cf1a15ac67f 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -174,6 +174,39 @@ func @reshape_canonicalize_double(%arg0: tensor<?x10xf32>) -> tensor<?x5xf32> {
 
 // -----
 
+// CHECK-LABEL: @reshape_canonicalize_const
+func @reshape_canonicalize_const() -> tensor<1x10xi32> {
+  // CHECK: %[[VAR0:.+]] = "tosa.const"() {value = dense<0> : tensor<1x10xi32>}
+  // CHECK: return %[[VAR0]]
+  %0 = "tosa.const"() {value = dense<0> : tensor<10xi32>} : () -> tensor<10xi32>
+  %1 = "tosa.reshape"(%0) {new_shape = [1, 10]} : (tensor<10xi32>) -> tensor<1x10xi32>
+  return %1 : tensor<1x10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @reshape_canonicalize_const_spat
+func @reshape_canonicalize_const_spat() -> (tensor<10xi32>, tensor<1x10xi32>) {
+  // CHECK-DAG: %[[VAR0:.+]] = "tosa.const"() {value = dense<0> : tensor<10xi32>}
+  // CHECK-DAG: %[[VAR1:.+]] = "tosa.const"() {value = dense<0> : tensor<1x10xi32>}
+  // CHECK: return %[[VAR0]], %[[VAR1]]
+  %0 = "tosa.const"() {value = dense<0> : tensor<10xi32>} : () -> tensor<10xi32>
+  %1 = "tosa.reshape"(%0) {new_shape = [1, 10]} : (tensor<10xi32>) -> tensor<1x10xi32>
+  return %0 , %1 : tensor<10xi32>, tensor<1x10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @reshape_canonicalize_const_sparse
+func @reshape_canonicalize_const_sparse() -> (tensor<3xi32>, tensor<1x3xi32>) {
+  //CHECK: "tosa.reshape"
+  %0 = "tosa.const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : ()-> tensor<3xi32>
+  %1 = "tosa.reshape"(%0) {new_shape = [1, 3]} : (tensor<3xi32>) -> tensor<1x3xi32>
+  return %0 , %1 : tensor<3xi32>, tensor<1x3xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @slice_fold
 func @slice_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   // CHECK: return %arg0


        


More information about the Mlir-commits mailing list