[Mlir-commits] [mlir] 9b6c2ea - [mlir][Linalg] Add GenericOp self-copy on buffers folding

Nicolas Vasilache llvmlistbot at llvm.org
Wed Jan 26 02:56:51 PST 2022


Author: Nicolas Vasilache
Date: 2022-01-26T05:56:31-05:00
New Revision: 9b6c2ea30219c16c264eaa38609e324470e2ad07

URL: https://github.com/llvm/llvm-project/commit/9b6c2ea30219c16c264eaa38609e324470e2ad07
DIFF: https://github.com/llvm/llvm-project/commit/9b6c2ea30219c16c264eaa38609e324470e2ad07.diff

LOG: [mlir][Linalg] Add GenericOp self-copy on buffers folding

Reviewed By: pifon2a

Differential Revision: https://reviews.llvm.org/D118116

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/inlining.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3ca3932a44eec..b3067109b4365 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -843,8 +843,6 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    if (!genericOp.hasTensorSemantics())
-      return failure();
     // Check all indexing maps are identity.
     if (llvm::any_of(genericOp.getIndexingMaps(),
                      [](AffineMap map) { return !map.isIdentity(); }))
@@ -859,6 +857,17 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
     if (!yieldOp)
       return failure();
 
+    // In the buffer case, we need to check exact buffer equality.
+    if (genericOp.hasBufferSemantics()) {
+      if (genericOp.getNumInputs() == 1 && genericOp.getNumOutputs() == 1 &&
+          genericOp.getInputOperand(0)->get() ==
+              genericOp.getOutputOperand(0)->get()) {
+        rewriter.eraseOp(genericOp);
+        return success();
+      }
+      return failure();
+    }
+
     // Get the argument number of the returned values. That is the operand
     // number to use for replacing uses of this operation.
     SmallVector<Value> returnedArgs;
@@ -876,6 +885,7 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
                                                       resultType, returnedArg);
       returnedArgs.push_back(returnedArg);
     }
+
     if (returnedArgs.size() != genericOp->getNumResults())
       return failure();
     rewriter.replaceOp(genericOp, returnedArgs);

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 44cb18f11d152..96d3aa26deafa 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -583,3 +583,19 @@ func @dim_of_tiled_loop_result_no_canonicalize(%arg0: tensor<?x?xf32>, %arg1: te
   %r2 = tensor.dim %r, %c0 : tensor<?x?xf32>
   return %r2 : index
 }
+
+// -----
+
+// CHECK: func @fold_self_copy
+func @fold_self_copy(%0 : memref<4x16xf32>) {
+// CHECK-NEXT: return
+  linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, 
+                                   affine_map<(d0, d1) -> (d0, d1)>], 
+                  iterator_types = ["parallel", "parallel"]} 
+    ins(%0 : memref<4x16xf32>) 
+    outs(%0 : memref<4x16xf32>) {
+      ^bb0(%arg4: f32, %arg5: f32):
+        linalg.yield %arg4 : f32
+    }
+  return 
+}

diff  --git a/mlir/test/Dialect/Linalg/inlining.mlir b/mlir/test/Dialect/Linalg/inlining.mlir
index 527b044fa2499..033213c2a954c 100644
--- a/mlir/test/Dialect/Linalg/inlining.mlir
+++ b/mlir/test/Dialect/Linalg/inlining.mlir
@@ -25,7 +25,8 @@ func @inlined_fn(%arg0: memref<?xf32>) {
      ins(%arg0 : memref<?xf32>)
     outs(%arg0 : memref<?xf32>) {
     ^bb(%0 : f32, %1 : f32) :
-      linalg.yield %0 : f32
+      %2 = arith.addf %0, %0: f32
+      linalg.yield %2 : f32
   }
   return
 }


        


More information about the Mlir-commits mailing list