[Mlir-commits] [mlir] 19435d3 - [mlir][linalg] Fold fill -> tensor_reshape chain

Lei Zhang llvmlistbot at llvm.org
Wed Mar 24 15:19:43 PDT 2021


Author: Lei Zhang
Date: 2021-03-24T18:17:58-04:00
New Revision: 19435d3863e512c62962dd872f5cbf630eaeab73

URL: https://github.com/llvm/llvm-project/commit/19435d3863e512c62962dd872f5cbf630eaeab73
DIFF: https://github.com/llvm/llvm-project/commit/19435d3863e512c62962dd872f5cbf630eaeab73.diff

LOG: [mlir][linalg] Fold fill -> tensor_reshape chain

For such op chains, the linalg.tensor_reshape op can be
replaced by a new linalg.fill op that directly produces
the reshape's result type.

Differential Revision: https://reviews.llvm.org/D99116

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9f8ade5b5fbca..fdb2e4f4603ee 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1658,12 +1658,43 @@ struct ReplaceDimOfReshapeOpResult : OpRewritePattern<memref::DimOp> {
     return success();
   }
 };
+
+/// Fold linalg.fill -> linalg.tensor_reshape chain.
+///
+/// For such op chains, we can create new linalg.fill ops with the result
+/// type of the linalg.tensor_reshape op. The fill's original init tensor is
+/// not reused; a fresh linalg.init_tensor of the reshaped type replaces it.
+///
+/// Only statically shaped results are folded: the new linalg.init_tensor is
+/// built from the static shape alone and cannot carry dynamic sizes.
+struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    auto oldFill = reshapeOp.src().getDefiningOp<FillOp>();
+    if (!oldFill)
+      return failure();
+
+    auto resultType = reshapeOp.getResultType();
+    // A dynamic dimension would need a matching size operand on
+    // linalg.init_tensor, which this rewrite does not compute; building the
+    // init from getShape() alone would yield an invalid op, so bail out.
+    if (!resultType.hasStaticShape())
+      return failure();
+
+    auto newInit = rewriter.create<InitTensorOp>(
+        oldFill.getLoc(), resultType.getShape(), resultType.getElementType());
+    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, newInit, oldFill.value());
+
+    return success();
+  }
+};
 } // namespace
 
 void TensorReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<CollapseReshapeOps<TensorReshapeOp>, FoldReshapeWithConstant,
-              ReplaceDimOfReshapeOpResult>(context);
+  results.add<CollapseReshapeOps<TensorReshapeOp>, FoldFillWithTensorReshape,
+              FoldReshapeWithConstant, ReplaceDimOfReshapeOpResult>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 693e94f636983..5ec93dda59d0a 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -802,3 +802,20 @@ func @self_copy(%arg0 : memref<2x3x?x4xf32>) {
 //   CHECK: return
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @fold_fill_reshape()
+func @fold_fill_reshape() -> tensor<6x4xf32> {
+  // CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
+  %zero = constant 0.0 : f32
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [6, 4] : tensor<6x4xf32>
+  %init = linalg.init_tensor [1, 2, 3, 4] : tensor<1x2x3x4xf32>
+  // CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]]) : tensor<6x4xf32>, f32 -> tensor<6x4xf32>
+  %fill = linalg.fill(%init, %zero) : tensor<1x2x3x4xf32>, f32 -> tensor<1x2x3x4xf32>
+  %reshape = linalg.tensor_reshape %fill [
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
+    affine_map<(d0, d1, d2, d3) -> (d3)>] : tensor<1x2x3x4xf32> into tensor<6x4xf32>
+  // CHECK: return %[[FILL]] : tensor<6x4xf32>
+  return %reshape : tensor<6x4xf32>
+}

diff  --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 54898d6f03527..cb5d1089eb85d 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -379,12 +379,13 @@ func @fold_unit_dim_for_init_tensor(%input: tensor<1x1000xf32>) -> tensor<1xf32>
 
 //       CHECK: func @fold_unit_dim_for_init_tensor
 
 //       CHECK: %[[INPUT_RESHAPE:.+]] = linalg.tensor_reshape %{{.+}} [#[[MAP0]]] : tensor<1x1000xf32> into tensor<1000xf32>
-//       CHECK: %[[INIT_RESHAPE:.+]] = linalg.tensor_reshape %{{.+}} [] : tensor<1xf32> into tensor<f32>
+//       CHECK: %[[INIT:.+]] = linalg.init_tensor [] : tensor<f32>
+//       CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %{{.+}}) : tensor<f32>, f32 -> tensor<f32>
 //       CHECK: %[[GENERIC:.+]] = linalg.generic
 //  CHECK-SAME:     indexing_maps = [#[[MAP1]], #[[MAP2]]]
 //  CHECK-SAME:     iterator_types = ["reduction"]
 //  CHECK-SAME:   ins(%[[INPUT_RESHAPE]] : tensor<1000xf32>)
-//  CHECK-SAME:   outs(%[[INIT_RESHAPE]] : tensor<f32>)
+//  CHECK-SAME:   outs(%[[FILL]] : tensor<f32>)
 //       CHECK: %[[GENERIC_RESHAPE:.+]] = linalg.tensor_reshape %[[GENERIC]] [] : tensor<f32> into tensor<1xf32>
-//       CHECK: return %[[GENERIC_RESHAPE:.+]] : tensor<1xf32>
+//       CHECK: return %[[GENERIC_RESHAPE]] : tensor<1xf32>


        


More information about the Mlir-commits mailing list