[llvm-branch-commits] [mlir] 722ae10 - [mlir][Linalg] Add canonicalization to remove no-op linalg operations.
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jan 14 15:04:16 PST 2021
Author: MaheshRavishankar
Date: 2021-01-14T14:59:24-08:00
New Revision: 722ae10907e06a0bafa00c557e5242b53419a3ce
URL: https://github.com/llvm/llvm-project/commit/722ae10907e06a0bafa00c557e5242b53419a3ce
DIFF: https://github.com/llvm/llvm-project/commit/722ae10907e06a0bafa00c557e5242b53419a3ce.diff
LOG: [mlir][Linalg] Add canonicalization to remove no-op linalg operations.
linalg.generic/indexed_generic operations on tensors whose indexing maps
are all identity maps and whose body just yields the (non-induction
variable) block arguments of the operation can be canonicalized by
replacing uses of the results with the corresponding operands.
Differential Revision: https://reviews.llvm.org/D94581
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8732065bb042..b74e44d91176 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2119,6 +2119,54 @@ struct DeduplicateInputs : public RewritePattern {
}
};
+/// Remove generic/indexed_generic operations (on tensors) that just copy values
+/// from operands to results. Requirements are 1) all indexing maps are
+/// identity, and 2) the body is a lone yield of its own block arguments, each
+/// forwarding an operand of the same type as the result it replaces.
+struct RemoveIdentityLinalgOps : public RewritePattern {
+  RemoveIdentityLinalgOps(PatternBenefit benefit = 1)
+      : RewritePattern(benefit, MatchAnyOpTypeTag()) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    if (!isa<GenericOp, IndexedGenericOp>(op))
+      return failure();
+    LinalgOp genericOp = cast<LinalgOp>(op);
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+    // Check all indexing maps are identity.
+    if (llvm::any_of(genericOp.getIndexingMaps(),
+                     [](AffineMap map) { return !map.isIdentity(); }))
+      return failure();
+    // The body must consist of nothing but a linalg.yield operation.
+    Block &body = op->getRegion(0).front();
+    if (!llvm::hasSingleElement(body))
+      return failure();
+    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
+    if (!yieldOp)
+      return failure();
+
+    // Only arguments owned by `body` map to operands (not block arguments of an
+    // enclosing block); each operand must have exactly its result's type.
+    unsigned numIndexArgs = genericOp.getNumPayloadInductionVariables();
+    SmallVector<Value, 4> returnedArgs;
+    for (auto en : llvm::enumerate(yieldOp.values())) {
+      auto yieldArg = en.value().dyn_cast<BlockArgument>();
+      if (!yieldArg || yieldArg.getOwner() != &body)
+        return failure();
+      unsigned argumentNumber = yieldArg.getArgNumber();
+      if (argumentNumber < numIndexArgs)
+        return failure();
+      Value returnedArg = op->getOperand(argumentNumber - numIndexArgs);
+      if (returnedArg.getType() != op->getResult(en.index()).getType())
+        return failure();
+      returnedArgs.push_back(returnedArg);
+    }
+    rewriter.replaceOp(genericOp, returnedArgs);
+    return success();
+  }
+};
+
/// Canonicalize a `linalgOp` -> `dim` pattern by replacing the `dim` arg
/// with the corresponding output tensor argument of the linalg op.
struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> {
@@ -2143,7 +2191,8 @@ struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> {
#define CANONICALIZERS_AND_FOLDERS(XXX) \
void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results, \
MLIRContext *context) { \
- results.insert<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp>(); \
+ results.insert<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp, \
+ RemoveIdentityLinalgOps>(); \
results.insert<ReplaceDimOfLinalgResult>(context); \
} \
\
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 6b806c801341..b2de3fdc6c8e 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -249,8 +249,10 @@ func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf
return %1: tensor<0xf32>
}
// CHECK-LABEL: @dce_zero_memref
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<0xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<0xf32>
// CHECK-NOT: linalg.copy
-// CHECK-NEXT: linalg.generic
+// CHECK-NEXT: return %[[ARG1]]
// -----
@@ -449,3 +451,30 @@ func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
// CHECK: %[[T0:.+]] = muli %[[ARG0]], %[[C28]]
// CHECK: %[[T1:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
// CHECK: return %[[T1]]
+
+// -----
+// Identity-map generic that only forwards its two inputs, swapped, to results.
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @remove_no_op(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>)
+ -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %0 = dim %arg0, %c0 : tensor<?x?x?xf32>
+ %1 = dim %arg0, %c1 : tensor<?x?x?xf32>
+ %2 = dim %arg0, %c2 : tensor<?x?x?xf32>
+ %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+ %4, %5 = linalg.generic {
+ indexing_maps = [#map, #map, #map, #map],
+ iterator_types = ["parallel", "parallel", "parallel"]
+ } ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs(%3, %3 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {  // inits never read
+ ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32, %arg5 : f32):
+ linalg.yield %arg3, %arg2 : f32, f32  // elements of %arg1 and %arg0
+ } -> tensor<?x?x?xf32>, tensor<?x?x?xf32>
+ return %4, %5 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @remove_no_op
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK: return %[[ARG1]], %[[ARG0]]
More information about the llvm-branch-commits
mailing list