[llvm-branch-commits] [mlir] 89ae5b5 - [mlir] Add canonicalization pattern out_tensor->linalg->dim to out_tensor->dim.

Alexander Belyaev via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Jan 5 06:20:41 PST 2021


Author: Alexander Belyaev
Date: 2021-01-05T15:15:21+01:00
New Revision: 89ae5b5b6a475addb7248ca7a948a944a15f0275

URL: https://github.com/llvm/llvm-project/commit/89ae5b5b6a475addb7248ca7a948a944a15f0275
DIFF: https://github.com/llvm/llvm-project/commit/89ae5b5b6a475addb7248ca7a948a944a15f0275.diff

LOG: [mlir] Add canonicalization pattern out_tensor->linalg->dim to out_tensor->dim.

Differential Revision: https://reviews.llvm.org/D94079

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index bcbd6d9036121..529ba35a0b87d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1958,14 +1958,33 @@ struct DeduplicateInputs : public RewritePattern {
     return success();
   }
 };
+
+/// Canonicalize `dim %r, %i`, where `%r` is a result of a linalg op, into
+/// `dim %out, %i`, where `%out` is the output tensor at `%r`'s result index.
+struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    Value dimOpArg = dimOp.memrefOrTensor();
+    auto linalgOp = dimOpArg.getDefiningOp<LinalgOp>();
+    if (!linalgOp)
+      return failure();
+
+    // Defined by `linalgOp`, so `dimOpArg` is an OpResult; no search needed.
+    unsigned resultIndex = dimOpArg.cast<OpResult>().getResultNumber();
+    Value outputTensor = linalgOp.getOutputTensors()[resultIndex];
+    rewriter.replaceOpWithNewOp<DimOp>(dimOp, outputTensor, dimOp.index());
+    return success();
+  }
+};
 } // namespace
 
 #define CANONICALIZERS_AND_FOLDERS(XXX)                                        \
   void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results,     \
                                         MLIRContext *context) {                \
-    results.insert<EraseDeadLinalgOp>();                                       \
-    results.insert<FoldTensorCastOp>();                                        \
-    results.insert<DeduplicateInputs>();                                       \
+    results.insert<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp>();  \
+    results.insert<ReplaceDimOfLinalgResult>(context);                         \
   }                                                                            \
                                                                                \
   LogicalResult XXX::fold(ArrayRef<Attribute>,                                 \

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index f015d5fd64fd9..faac64c0d91a9 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -389,3 +389,31 @@ func @init_tensor_dynamic_dim(%arg0 : index) -> (index) {
 //      CHECK: func @init_tensor_dynamic_dim
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: index
 //      CHECK:   return %[[ARG0]]
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+// Check that `dim` of a linalg result becomes `dim` of its `outs` operand.
+func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
+    %arg_1: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+  %0, %1 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel"]
+  } ins(%arg_0 : tensor<?xf32>)
+    outs(%arg_0, %arg_1 : tensor<?xf32>, tensor<?xf32>) {
+  ^bb0(%in: f32, %out_0: f32, %out_1: f32):
+    linalg.yield %in, %in : f32, f32
+  } -> tensor<?xf32>, tensor<?xf32>
+
+  %c0 = constant 0 : index
+  %num_elem_0 = dim %0, %c0 : tensor<?xf32>
+  %result_0 = linalg.init_tensor [%num_elem_0] : tensor<?xf32>
+
+  %num_elem_1 = dim %1, %c0 : tensor<?xf32>
+  %result_1 = linalg.init_tensor [%num_elem_1] : tensor<?xf32>
+  return %result_0, %result_1 : tensor<?xf32>, tensor<?xf32>
+}
+// CHECK-LABEL: func @init_tensor_dim_of_linalg_result(
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?xf32>, %[[ARG_1:[a-zA-Z0-9_]+]]: tensor<?xf32>)
+// CHECK: dim %[[ARG_0]],
+// CHECK: dim %[[ARG_1]],


        


More information about the llvm-branch-commits mailing list