[Mlir-commits] [mlir] 0ee5323 - [mlir][linalg][bufferize][NFC] Simplify getAliasingOpOperand signature

Matthias Springer llvmlistbot at llvm.org
Tue Oct 12 17:48:40 PDT 2021


Author: Matthias Springer
Date: 2021-10-13T09:48:28+09:00
New Revision: 0ee53231894f31ebe0772a7b3cca326b04abd46c

URL: https://github.com/llvm/llvm-project/commit/0ee53231894f31ebe0772a7b3cca326b04abd46c
DIFF: https://github.com/llvm/llvm-project/commit/0ee53231894f31ebe0772a7b3cca326b04abd46c.diff

LOG: [mlir][linalg][bufferize][NFC] Simplify getAliasingOpOperand signature

Differential Revision: https://reviews.llvm.org/D110982

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index a9957ada9f59..0255b3a349ab 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -551,41 +551,41 @@ static OpResult getInplaceableOpResult(OpOperand &opOperand) {
 }
 
 /// Determine which OpOperand* will alias with `result` if the op is bufferized
-/// in place.
-/// Return None if the owner of `opOperand` does not have known
-/// bufferization aliasing behavior, which indicates that the op must allocate
-/// all of its tensor results.
-/// TODO: in the future this may need to evolve towards a list of OpOperand*.
-static Optional<OpOperand *> getAliasingOpOperand(OpResult result) {
+/// in place. Note that multiple OpOperands may potentially alias with an
+/// OpResult. E.g.: std.select in the future.
+static SmallVector<OpOperand *> getAliasingOpOperand(OpResult result) {
+  SmallVector<OpOperand *> r;
+  // Unknown ops are handled conservatively and never bufferize in-place.
   if (!hasKnownBufferizationAliasingBehavior(result.getDefiningOp()))
-    return None;
-  return TypeSwitch<Operation *, OpOperand *>(result.getDefiningOp())
-      .Case([&](tensor::CastOp op) { return &op->getOpOperand(0); })
-      .Case([&](ConstantOp op) { return nullptr; })
-      .Case([&](ExtractSliceOp op) { return &op->getOpOperand(0); })
+    return SmallVector<OpOperand *>();
+  TypeSwitch<Operation *>(result.getDefiningOp())
+      .Case([&](tensor::CastOp op) { r.push_back(&op->getOpOperand(0)); })
+      .Case([&](ExtractSliceOp op) { r.push_back(&op->getOpOperand(0)); })
       // In the case of scf::ForOp, this currently assumes the iter_args / yield
       // are 1-1. This may fail and is verified at the end.
       // TODO: update this.
       .Case([&](scf::ForOp op) {
-        return &op.getIterOpOperands()[result.getResultNumber()];
+        r.push_back(&op.getIterOpOperands()[result.getResultNumber()]);
       })
-      .Case([&](InitTensorOp op) { return nullptr; })
-      .Case([&](InsertSliceOp op) { return &op->getOpOperand(1); })
+      .Case([&](InsertSliceOp op) { r.push_back(&op->getOpOperand(1)); })
       .Case([&](LinalgOp op) {
-        return op.getOutputTensorOperands()[result.getResultNumber()];
+        r.push_back(op.getOutputTensorOperands()[result.getResultNumber()]);
       })
       .Case([&](TiledLoopOp op) {
         // TODO: TiledLoopOp helper method to avoid leaking impl details.
-        return &op->getOpOperand(op.getNumControlOperands() +
-                                 op.getNumInputs() + result.getResultNumber());
+        r.push_back(&op->getOpOperand(op.getNumControlOperands() +
+                                      op.getNumInputs() +
+                                      result.getResultNumber()));
       })
-      .Case([&](vector::TransferWriteOp op) { return &op->getOpOperand(1); })
-      .Case([&](CallOpInterface op) { return nullptr; })
+      .Case([&](vector::TransferWriteOp op) {
+        r.push_back(&op->getOpOperand(1));
+      })
+      .Case<ConstantOp, CallOpInterface, InitTensorOp>([&](auto op) {})
       .Default([&](Operation *op) {
         op->dump();
         llvm_unreachable("unexpected defining op");
-        return nullptr;
       });
+  return r;
 }
 
 /// If the an ExtractSliceOp is bufferized in-place, the source operand will
@@ -879,8 +879,11 @@ bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference(
 ///      dominance).
 bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference(
     OpOperand &operand, OpResult result, const DominanceInfo &domInfo) const {
-  assert(getAliasingOpOperand(result) == &operand &&
+#ifndef NDEBUG
+  SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result);
+  assert(llvm::find(opOperands, &operand) != opOperands.end() &&
          "operand and result do not match");
+#endif // NDEBUG
 
   Operation *opToBufferize = result.getDefiningOp();
   Value opResult = result;
@@ -975,8 +978,11 @@ bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference(
 /// a write to a non-writable buffer.
 bool BufferizationAliasInfo::wouldCreateWriteToNonWritableBuffer(
     OpOperand &opOperand, OpResult opResult) const {
-  assert(getAliasingOpOperand(opResult) == &opOperand &&
+#ifndef NDEBUG
+  SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
+  assert(llvm::find(opOperands, &opOperand) != opOperands.end() &&
          "operand and result do not match");
+#endif // NDEBUG
 
   // Certain buffers are not writeable:
   //   1. A function bbArg that is not inplaceable or
@@ -1126,9 +1132,10 @@ bool BufferizationAliasInfo::existsInterleavedValueClobber(
       Operation *candidateOp = mit->v.getDefiningOp();
       if (!candidateOp)
         continue;
-      auto maybeAliasingOperand = getAliasingOpOperand(mit->v.cast<OpResult>());
-      if (!maybeAliasingOperand || !*maybeAliasingOperand ||
-          !bufferizesToMemoryWrite(**maybeAliasingOperand))
+      SmallVector<OpOperand *> operands =
+          getAliasingOpOperand(mit->v.cast<OpResult>());
+      assert(operands.size() <= 1 && "more than 1 OpOperand not supported yet");
+      if (operands.empty() || !bufferizesToMemoryWrite(*operands.front()))
         continue;
       LDBG("---->clobbering candidate: " << printOperationInfo(candidateOp)
                                          << '\n');
@@ -1414,9 +1421,11 @@ static Value getResultBuffer(OpBuilder &b, OpResult result,
                              bool skipCopy = false) {
   OpBuilder::InsertionGuard guard(b);
   Operation *op = result.getOwner();
-  Optional<OpOperand *> maybeOperand = getAliasingOpOperand(result);
-  assert(maybeOperand && "corresponding OpOperand not found");
-  Value operand = (*maybeOperand)->get();
+  SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
+  // TODO: Support multiple OpOperands.
+  assert(aliasingOperands.size() == 1 &&
+         "more than 1 OpOperand not supported yet");
+  Value operand = aliasingOperands.front()->get();
   Value operandBuffer = lookup(bvm, operand);
   assert(operandBuffer && "operand buffer not found");
 
@@ -2159,8 +2168,11 @@ static LogicalResult
 bufferizableInPlaceAnalysisImpl(OpOperand &operand, OpResult result,
                                 BufferizationAliasInfo &aliasInfo,
                                 const DominanceInfo &domInfo) {
-  assert(getAliasingOpOperand(result) == &operand &&
+#ifndef NDEBUG
+  SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result);
+  assert(llvm::find(opOperands, &operand) != opOperands.end() &&
          "operand and result do not match");
+#endif // NDEBUG
 
   int64_t resultNumber = result.getResultNumber();
   (void)resultNumber;


        


More information about the Mlir-commits mailing list