[llvm-branch-commits] [mlir] 89ae5b5 - [mlir] Add canonicalization pattern out_tensor->linalg->dim to out_tensor->dim.
Alexander Belyaev via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Jan 5 06:20:41 PST 2021
Author: Alexander Belyaev
Date: 2021-01-05T15:15:21+01:00
New Revision: 89ae5b5b6a475addb7248ca7a948a944a15f0275
URL: https://github.com/llvm/llvm-project/commit/89ae5b5b6a475addb7248ca7a948a944a15f0275
DIFF: https://github.com/llvm/llvm-project/commit/89ae5b5b6a475addb7248ca7a948a944a15f0275.diff
LOG: [mlir] Add canonicalization pattern out_tensor->linalg->dim to out_tensor->dim.
Differential Revision: https://reviews.llvm.org/D94079
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index bcbd6d9036121..529ba35a0b87d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1958,14 +1958,33 @@ struct DeduplicateInputs : public RewritePattern {
return success();
}
};
+
+/// Canonicalize `dim(linalgOp.result#i)` to `dim(linalgOp.outputTensor#i)`,
+/// relying on tensor result #i being tied to output tensor operand #i.
+struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    Value dimOpArg = dimOp.memrefOrTensor();
+    auto linalgOp = dimOpArg.getDefiningOp<LinalgOp>();
+    if (!linalgOp)
+      return failure();
+    // Has a defining op => it is an OpResult; get its index without a search.
+    unsigned resultIdx = dimOpArg.cast<OpResult>().getResultNumber();
+    auto outputTensors = linalgOp.getOutputTensors();
+    if (resultIdx >= outputTensors.size())
+      return failure(); // No output tensor tied to this result; leave as is.
+    rewriter.replaceOpWithNewOp<DimOp>(dimOp, outputTensors[resultIdx], dimOp.index());
+    return success();
+  }
+};
} // namespace
#define CANONICALIZERS_AND_FOLDERS(XXX) \
void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results, \
MLIRContext *context) { \
- results.insert<EraseDeadLinalgOp>(); \
- results.insert<FoldTensorCastOp>(); \
- results.insert<DeduplicateInputs>(); \
+ results.insert<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp>(); \
+ results.insert<ReplaceDimOfLinalgResult>(context); \
} \
\
LogicalResult XXX::fold(ArrayRef<Attribute>, \
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index f015d5fd64fd9..faac64c0d91a9 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -389,3 +389,31 @@ func @init_tensor_dynamic_dim(%arg0 : index) -> (index) {
// CHECK: func @init_tensor_dynamic_dim
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK: return %[[ARG0]]
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+
+func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
+    %arg_1: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+  %0, %1 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel"]
+  } ins(%arg_0 : tensor<?xf32>)
+    outs(%arg_0, %arg_1 : tensor<?xf32>, tensor<?xf32>) {
+  ^bb0(%in: f32, %out_0: f32, %out_1: f32):
+    linalg.yield %in, %in : f32, f32
+  } -> tensor<?xf32>, tensor<?xf32>
+  // Each `dim` of a generic result should fold to a `dim` of its tied outs operand.
+  %c0 = constant 0 : index
+  %num_elem_0 = dim %0, %c0 : tensor<?xf32>
+  %result_0 = linalg.init_tensor [%num_elem_0] : tensor<?xf32>
+
+  %num_elem_1 = dim %1, %c0 : tensor<?xf32>
+  %result_1 = linalg.init_tensor [%num_elem_1] : tensor<?xf32>
+  return %result_0, %result_1 : tensor<?xf32>, tensor<?xf32>
+}
+// CHECK-LABEL: func @init_tensor_dim_of_linalg_result(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?xf32>, %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?xf32>)
+// CHECK:         dim %[[ARG0]]
+// CHECK:         dim %[[ARG1]]
More information about the llvm-branch-commits mailing list