[llvm-branch-commits] [mlir] 722ae10 - [mlir][Linalg] Add canonicalization to remove no-op linalg operations.

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Jan 14 15:04:16 PST 2021


Author: MaheshRavishankar
Date: 2021-01-14T14:59:24-08:00
New Revision: 722ae10907e06a0bafa00c557e5242b53419a3ce

URL: https://github.com/llvm/llvm-project/commit/722ae10907e06a0bafa00c557e5242b53419a3ce
DIFF: https://github.com/llvm/llvm-project/commit/722ae10907e06a0bafa00c557e5242b53419a3ce.diff

LOG: [mlir][Linalg] Add canonicalization to remove no-op linalg operations.

linalg.generic/indexed_generic operations on tensors whose body just
yields the (non-induction-variable) block arguments of the operation
can be canonicalized away by replacing all uses of their results with
the corresponding operands.
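
For example (a minimal sketch modeled on the test case added below; the
function and value names are illustrative), the following generic op simply
forwards its input:

  #map = affine_map<(d0, d1) -> (d0, d1)>
  func @no_op(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
    %c0 = constant 0 : index
    %c1 = constant 1 : index
    %d0 = dim %arg0, %c0 : tensor<?x?xf32>
    %d1 = dim %arg0, %c1 : tensor<?x?xf32>
    %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
    %0 = linalg.generic {
      indexing_maps = [#map, #map],
      iterator_types = ["parallel", "parallel"]
    } ins(%arg0 : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {
    ^bb0(%in : f32, %out : f32):
      linalg.yield %in : f32
    } -> tensor<?x?xf32>
    return %0 : tensor<?x?xf32>
  }

After canonicalization, uses of %0 are replaced by %arg0, and the
linalg.generic (together with the now-dead init_tensor and dim ops) can be
removed.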

Differential Revision: https://reviews.llvm.org/D94581

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8732065bb042..b74e44d91176 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2119,6 +2119,54 @@ struct DeduplicateInputs : public RewritePattern {
   }
 };
 
+/// Remove generic/indexed_generic operations (on tensors) that are just
+/// copying the values from inputs to the results. Requirements are:
+/// 1) All indexing maps are identity maps.
+/// 2) The body contains only a linalg.yield operation whose yielded values
+///    are the block arguments corresponding to the operands.
+struct RemoveIdentityLinalgOps : public RewritePattern {
+  RemoveIdentityLinalgOps(PatternBenefit benefit = 1)
+      : RewritePattern(benefit, MatchAnyOpTypeTag()) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    if (!isa<GenericOp, IndexedGenericOp>(op))
+      return failure();
+    LinalgOp genericOp = cast<LinalgOp>(op);
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+    // Check all indexing maps are identity.
+    if (llvm::any_of(genericOp.getIndexingMaps(),
+                     [](AffineMap map) { return !map.isIdentity(); }))
+      return failure();
+
+    // Check that the body of the linalg operation is just a linalg.yield
+    // operation.
+    Block &body = op->getRegion(0).front();
+    if (!llvm::hasSingleElement(body))
+      return failure();
+    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
+    if (!yieldOp)
+      return failure();
+
+    // Get the argument number of the returned values. That is the operand
+    // number to use for replacing uses of this operation.
+    unsigned numIndexArgs = genericOp.getNumPayloadInductionVariables();
+    SmallVector<Value, 4> returnedArgs;
+    for (Value yieldVal : yieldOp.values()) {
+      auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
+      if (!yieldArg)
+        return failure();
+      unsigned argumentNumber = yieldArg.getArgNumber();
+      if (argumentNumber < numIndexArgs)
+        return failure();
+      returnedArgs.push_back(op->getOperand(argumentNumber - numIndexArgs));
+    }
+    rewriter.replaceOp(genericOp, returnedArgs);
+    return success();
+  }
+};
+
 /// Canonicalize a `linalgOp` -> `dim` pattern by replacing the `dim` arg
 /// with the corresponding output tensor argument of the linalg op.
 struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> {
@@ -2143,7 +2191,8 @@ struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> {
 #define CANONICALIZERS_AND_FOLDERS(XXX)                                        \
   void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results,     \
                                         MLIRContext *context) {                \
-    results.insert<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp>();  \
+    results.insert<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp,     \
+                   RemoveIdentityLinalgOps>();                                 \
     results.insert<ReplaceDimOfLinalgResult>(context);                         \
   }                                                                            \
                                                                                \

diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 6b806c801341..b2de3fdc6c8e 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -249,8 +249,10 @@ func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf
   return %1: tensor<0xf32>
 }
 // CHECK-LABEL: @dce_zero_memref
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<0xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<0xf32>
 //   CHECK-NOT:   linalg.copy
-//  CHECK-NEXT:   linalg.generic
+//  CHECK-NEXT:   return %[[ARG1]]
 
 // -----
 
@@ -449,3 +451,30 @@ func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
 //      CHECK:   %[[T0:.+]] = muli %[[ARG0]], %[[C28]]
 //      CHECK:   %[[T1:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
 //      CHECK:   return %[[T1]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @remove_no_op(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>)
+  -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %0 = dim %arg0, %c0 : tensor<?x?x?xf32>
+  %1 = dim %arg0, %c1 : tensor<?x?x?xf32>
+  %2 = dim %arg0, %c2 : tensor<?x?x?xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4, %5 = linalg.generic {
+    indexing_maps = [#map, #map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+    outs(%3, %3 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32, %arg5 : f32):
+    linalg.yield %arg3, %arg2 : f32, f32
+  } -> tensor<?x?x?xf32>, tensor<?x?x?xf32>
+  return %4, %5 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @remove_no_op
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+//       CHECK:     return %[[ARG1]], %[[ARG0]]
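
For reference, the pattern is registered through getCanonicalizationPatterns,
so it is exercised by the standard canonicalizer; one way to check the new
test locally (assuming a built mlir-opt; the exact RUN line in the test file
may differ) is:

  mlir-opt -split-input-file -canonicalize mlir/test/Dialect/Linalg/canonicalize.mlir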
