[Mlir-commits] [mlir] 9b6c2ea - [mlir][Linalg] Add GenericOp self-copy on buffers folding
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Jan 26 02:56:51 PST 2022
Author: Nicolas Vasilache
Date: 2022-01-26T05:56:31-05:00
New Revision: 9b6c2ea30219c16c264eaa38609e324470e2ad07
URL: https://github.com/llvm/llvm-project/commit/9b6c2ea30219c16c264eaa38609e324470e2ad07
DIFF: https://github.com/llvm/llvm-project/commit/9b6c2ea30219c16c264eaa38609e324470e2ad07.diff
LOG: [mlir][Linalg] Add GenericOp self-copy on buffers folding
Reviewed By: pifon2a
Differential Revision: https://reviews.llvm.org/D118116
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/inlining.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3ca3932a44eec..b3067109b4365 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -843,8 +843,6 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
- if (!genericOp.hasTensorSemantics())
- return failure();
// Check all indexing maps are identity.
if (llvm::any_of(genericOp.getIndexingMaps(),
[](AffineMap map) { return !map.isIdentity(); }))
@@ -859,6 +857,17 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
if (!yieldOp)
return failure();
+ // In the buffer case, we need to check exact buffer equality.
+ if (genericOp.hasBufferSemantics()) {
+ if (genericOp.getNumInputs() == 1 && genericOp.getNumOutputs() == 1 &&
+ genericOp.getInputOperand(0)->get() ==
+ genericOp.getOutputOperand(0)->get()) {
+ rewriter.eraseOp(genericOp);
+ return success();
+ }
+ return failure();
+ }
+
// Get the argument number of the returned values. That is the operand
// number to use for replacing uses of this operation.
SmallVector<Value> returnedArgs;
@@ -876,6 +885,7 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
resultType, returnedArg);
returnedArgs.push_back(returnedArg);
}
+
if (returnedArgs.size() != genericOp->getNumResults())
return failure();
rewriter.replaceOp(genericOp, returnedArgs);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 44cb18f11d152..96d3aa26deafa 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -583,3 +583,19 @@ func @dim_of_tiled_loop_result_no_canonicalize(%arg0: tensor<?x?xf32>, %arg1: te
%r2 = tensor.dim %r, %c0 : tensor<?x?xf32>
return %r2 : index
}
+
+// -----
+
+// CHECK: func @fold_self_copy
+func @fold_self_copy(%0 : memref<4x16xf32>) {
+// CHECK-NEXT: return
+ linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0 : memref<4x16xf32>)
+ outs(%0 : memref<4x16xf32>) {
+ ^bb0(%arg4: f32, %arg5: f32):
+ linalg.yield %arg4 : f32
+ }
+ return
+}
diff --git a/mlir/test/Dialect/Linalg/inlining.mlir b/mlir/test/Dialect/Linalg/inlining.mlir
index 527b044fa2499..033213c2a954c 100644
--- a/mlir/test/Dialect/Linalg/inlining.mlir
+++ b/mlir/test/Dialect/Linalg/inlining.mlir
@@ -25,7 +25,8 @@ func @inlined_fn(%arg0: memref<?xf32>) {
ins(%arg0 : memref<?xf32>)
outs(%arg0 : memref<?xf32>) {
^bb(%0 : f32, %1 : f32) :
- linalg.yield %0 : f32
+ %2 = arith.addf %0, %0: f32
+ linalg.yield %2 : f32
}
return
}
More information about the Mlir-commits mailing list