[Mlir-commits] [mlir] 19435d3 - [mlir][linalg] Fold fill -> tensor_reshape chain
Lei Zhang
llvmlistbot at llvm.org
Wed Mar 24 15:19:43 PDT 2021
Author: Lei Zhang
Date: 2021-03-24T18:17:58-04:00
New Revision: 19435d3863e512c62962dd872f5cbf630eaeab73
URL: https://github.com/llvm/llvm-project/commit/19435d3863e512c62962dd872f5cbf630eaeab73
DIFF: https://github.com/llvm/llvm-project/commit/19435d3863e512c62962dd872f5cbf630eaeab73.diff
LOG: [mlir][linalg] Fold fill -> tensor_reshape chain
For such op chains, we can create new linalg.fill ops
with the result type of the linalg.tensor_reshape op.
Differential Revision: https://reviews.llvm.org/D99116
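To illustrate the fold concretely (a sketch mirroring the canonicalize.mlir test added below, using the linalg syntax as of this revision), IR of the form

    // Fill a tensor, then reshape the filled result.
    %zero = constant 0.0 : f32
    %init = linalg.init_tensor [1, 2, 3, 4] : tensor<1x2x3x4xf32>
    %fill = linalg.fill(%init, %zero) : tensor<1x2x3x4xf32>, f32 -> tensor<1x2x3x4xf32>
    %reshape = linalg.tensor_reshape %fill [
        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
        affine_map<(d0, d1, d2, d3) -> (d3)>] : tensor<1x2x3x4xf32> into tensor<6x4xf32>

canonicalizes into a fill performed directly in the reshaped type:

    // Fill a fresh init tensor of the target shape; no reshape needed.
    %zero = constant 0.0 : f32
    %init = linalg.init_tensor [6, 4] : tensor<6x4xf32>
    %fill = linalg.fill(%init, %zero) : tensor<6x4xf32>, f32 -> tensor<6x4xf32>

The pattern materializes a new linalg.init_tensor with the reshaped shape; its contents are irrelevant because the fill overwrites every element.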
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9f8ade5b5fbca..fdb2e4f4603ee 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1658,12 +1658,33 @@ struct ReplaceDimOfReshapeOpResult : OpRewritePattern<memref::DimOp> {
     return success();
   }
 };
+
+/// Fold linalg.fill -> linalg.tensor_reshape chain.
+///
+/// For such op chains, we can create new linalg.fill ops with the result
+/// type of the linalg.tensor_reshape op.
+struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    auto oldFill = reshapeOp.src().getDefiningOp<FillOp>();
+    if (!oldFill)
+      return failure();
+
+    auto newInit = rewriter.create<InitTensorOp>(
+        oldFill.getLoc(), reshapeOp.getResultType().getShape(),
+        reshapeOp.getResultType().getElementType());
+    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, newInit, oldFill.value());
+
+    return success();
+  }
+};
 } // namespace
 void TensorReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<CollapseReshapeOps<TensorReshapeOp>, FoldReshapeWithConstant,
-              ReplaceDimOfReshapeOpResult>(context);
+  results.add<CollapseReshapeOps<TensorReshapeOp>, FoldFillWithTensorReshape,
+              FoldReshapeWithConstant, ReplaceDimOfReshapeOpResult>(context);
 }
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 693e94f636983..5ec93dda59d0a 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -802,3 +802,19 @@ func @self_copy(%arg0 : memref<2x3x?x4xf32>) {
   // CHECK: return
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @fold_fill_reshape()
+func @fold_fill_reshape() -> tensor<6x4xf32> {
+  %zero = constant 0.0 : f32
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [6, 4] : tensor<6x4xf32>
+  %init = linalg.init_tensor [1, 2, 3, 4] : tensor<1x2x3x4xf32>
+  // CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %cst) : tensor<6x4xf32>, f32 -> tensor<6x4xf32>
+  %fill = linalg.fill(%init, %zero) : tensor<1x2x3x4xf32>, f32 -> tensor<1x2x3x4xf32>
+  %reshape = linalg.tensor_reshape %fill [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
+      affine_map<(d0, d1, d2, d3) -> (d3)>] : tensor<1x2x3x4xf32> into tensor<6x4xf32>
+  // CHECK: return %[[FILL]] : tensor<6x4xf32>
+  return %reshape : tensor<6x4xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 54898d6f03527..cb5d1089eb85d 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -379,12 +379,14 @@ func @fold_unit_dim_for_init_tensor(%input: tensor<1x1000xf32>) -> tensor<1xf32>
 // CHECK: func @fold_unit_dim_for_init_tensor
+
 // CHECK: %[[INPUT_RESHAPE:.+]] = linalg.tensor_reshape %{{.+}} [#[[MAP0]]] : tensor<1x1000xf32> into tensor<1000xf32>
-// CHECK: %[[INIT_RESHAPE:.+]] = linalg.tensor_reshape %{{.+}} [] : tensor<1xf32> into tensor<f32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [] : tensor<f32>
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %cst) : tensor<f32>, f32 -> tensor<f32>
 // CHECK: %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP2]]]
 // CHECK-SAME: iterator_types = ["reduction"]
 // CHECK-SAME: ins(%[[INPUT_RESHAPE]] : tensor<1000xf32>)
-// CHECK-SAME: outs(%[[INIT_RESHAPE]] : tensor<f32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<f32>)
 // CHECK: %[[GENERIC_RESHAPE:.+]] = linalg.tensor_reshape %[[GENERIC]] [] : tensor<f32> into tensor<1xf32>
 // CHECK: return %[[GENERIC_RESHAPE:.+]] : tensor<1xf32>