[Mlir-commits] [mlir] 7c9b6a3 - [mlir][linalg] ComprehensiveBufferize: Do not copy InitTensorOps

Matthias Springer llvmlistbot at llvm.org
Mon Sep 13 06:32:10 PDT 2021


Author: Matthias Springer
Date: 2021-09-13T22:31:54+09:00
New Revision: 7c9b6a3355ee8226a880edcf88302bc0360f33b5

URL: https://github.com/llvm/llvm-project/commit/7c9b6a3355ee8226a880edcf88302bc0360f33b5
DIFF: https://github.com/llvm/llvm-project/commit/7c9b6a3355ee8226a880edcf88302bc0360f33b5.diff

LOG: [mlir][linalg] ComprehensiveBufferize: Do not copy InitTensorOps

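When bufferizing scf::ForOp and TiledLoopOp, do not copy InitTensorOps or (chains of) tensor::CastOps thereof into the result buffer. Previously, only an operand defined directly by an InitTensorOp skipped the copy.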

Differential Revision: https://reviews.llvm.org/D109656

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index cb7f0c39d304..3d810bfd48ae 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -172,6 +172,14 @@ static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
   return returnOp;
 }
 
+/// Return true if `value` is the result of an InitTensorOp or a cast thereof.
+static bool isInitTensorOp(Value value) {
+  tensor::CastOp castOp;
+  while ((castOp = value.getDefiningOp<tensor::CastOp>()))
+    value = castOp.source();
+  return value.getDefiningOp<InitTensorOp>();
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific BlockAndValueMapping support with debugging.
 //===----------------------------------------------------------------------===//
@@ -1781,7 +1789,7 @@ static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
       // unitialized and we do not need to copy.
       // TODO: "matching bbArg does not bufferize to a read" is a more general
       // check.
-      if (!operand.getDefiningOp<linalg::InitTensorOp>())
+      if (!isInitTensorOp(operand))
         b.create<linalg::CopyOp>(forOp.getLoc(), operandBuffer, resultBuffer);
     }
     BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
@@ -1908,7 +1916,7 @@ static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,
       // unitialized and we do not need to copy.
       // TODO: "matching bbArg does not bufferize to a read" is a more general
       // check.
-      if (!oldOutputTensor.getDefiningOp<linalg::InitTensorOp>()) {
+      if (!isInitTensorOp(oldOutputTensor)) {
         b.setInsertionPointAfter(alloc.getDefiningOp());
         b.create<linalg::CopyOp>(loc, outputBuffer, alloc);
       }

