[Mlir-commits] [mlir] 585a8a3 - [mlir][bufferize] OpOperands can have multiple aliasing OpResults

Matthias Springer llvmlistbot at llvm.org
Wed Feb 9 04:03:28 PST 2022


Author: Matthias Springer
Date: 2022-02-09T20:58:45+09:00
New Revision: 585a8a321c263a6c793ac800f05f2805f45a4feb

URL: https://github.com/llvm/llvm-project/commit/585a8a321c263a6c793ac800f05f2805f45a4feb
DIFF: https://github.com/llvm/llvm-project/commit/585a8a321c263a6c793ac800f05f2805f45a4feb.diff

LOG: [mlir][bufferize] OpOperands can have multiple aliasing OpResults

This makes getAliasingOpResult symmetric to getAliasingOpOperand: it now returns a vector of OpResults instead of a single (possibly null) OpResult. The previous single-result signature was confusing for users; it was designed that way only because no bufferizable op currently has multiple aliasing OpResults.

Differential Revision: https://reviews.llvm.org/D119259

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
    mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 609a1bb520c9d..714f75d09b965 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -180,9 +180,8 @@ class BufferizationState {
   SmallVector<OpOperand *> getAliasingOpOperand(OpResult result) const;
 
   /// Determine which OpResult will alias with `opOperand` if the op is
-  /// bufferized in place. Return an empty OpResult if the op is not
-  /// bufferizable.
-  OpResult getAliasingOpResult(OpOperand &opOperand) const;
+  /// bufferized in place. Return an empty vector if the op is not bufferizable.
+  SmallVector<OpResult> getAliasingOpResult(OpOperand &opOperand) const;
 
   /// Return true if `opOperand` bufferizes to a memory read. Return `true` if
   /// the op is not bufferizable.
@@ -396,9 +395,10 @@ struct AllocationHoistingBarrierOnly
     return {};
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                               const BufferizationState &state) const {
-    return OpResult();
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
+    return {};
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,

diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index e78d00ab2ac80..f6c51dae92eaf 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -124,7 +124,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           bufferized in-place. This method will never be called on OpOperands
           that do not have a tensor type.
         }],
-        /*retType=*/"OpResult",
+        /*retType=*/"SmallVector<OpResult>",
         /*methodName=*/"getAliasingOpResult",
         /*args=*/(ins "OpOperand &":$opOperand,
                       "const BufferizationState &":$state),
@@ -162,8 +162,10 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           for (OpOperand &opOperand : $_op.getOperation()->getOpOperands()) {
             if (!opOperand.get().getType().isa<TensorType>())
               continue;
-            if (bufferizableOp.getAliasingOpResult(opOperand, state) ==
-                    opResult)
+            SmallVector<OpResult> aliasingOpResults =
+                bufferizableOp.getAliasingOpResult(opOperand, state);
+            if (llvm::find(aliasingOpResults, opResult)
+                != aliasingOpResults.end())
               result.push_back(&opOperand);
           }
           return result;
@@ -304,8 +306,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           cast<BufferizableOpInterface>(getOperation());
       return !bufferizableOp.bufferizesToMemoryRead(opOperand, state)
           && !bufferizableOp.bufferizesToMemoryWrite(opOperand, state)
-          && static_cast<bool>(
-              bufferizableOp.getAliasingOpResult(opOperand, state));
+          && !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
     }
 
     // TODO: The following two attributes should belong to the tensor dialect.

diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index a69972b9e2822..559f9b4380813 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -211,9 +211,9 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
       return true;
     }
 
-    OpResult getAliasingOpResult(OpOperand &opOperand,
-                                 const BufferizationState &state) const {
-      return OpResult();
+    SmallVector<OpResult> getAliasingOpResult(
+        OpOperand &opOperand, const BufferizationState &state) const {
+      return {};
     }
 
     LogicalResult bufferize(RewriterBase &rewriter,

diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
index 02526264ce4f3..3a01b397abef7 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -69,9 +69,10 @@ struct IndexCastOpInterface
     return false;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                               const BufferizationState &state) const {
-    return op->getResult(0);
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
+    return {op->getResult(0)};
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
@@ -114,9 +115,10 @@ struct SelectOpInterface
     return false;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                               const BufferizationState &state) const {
-    return op->getOpResult(0) /*result*/;
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
+    return {op->getOpResult(0) /*result*/};
   }
 
   SmallVector<OpOperand *>

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index f403710437d27..261a107c0e56e 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -87,12 +87,13 @@ BufferizationState::getAliasingOpOperand(OpResult result) const {
 }
 
 /// Determine which OpResult will alias with `opOperand` if the op is bufferized
-/// in place. Return an empty OpResult if the op is not bufferizable.
-OpResult BufferizationState::getAliasingOpResult(OpOperand &opOperand) const {
+/// in place. Return an empty vector if the op is not bufferizable.
+SmallVector<OpResult>
+BufferizationState::getAliasingOpResult(OpOperand &opOperand) const {
   if (auto bufferizableOp =
           dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
     return bufferizableOp.getAliasingOpResult(opOperand, *this);
-  return OpResult();
+  return {};
 }
 
 /// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
@@ -144,8 +145,9 @@ bool BufferizationState::isValueRead(Value value) const {
     OpOperand *uMaybeReading = workingSet.pop_back_val();
     // Skip over all ops that neither read nor write (but create an alias).
     if (bufferizesToAliasOnly(*uMaybeReading))
-      for (OpOperand &use : getAliasingOpResult(*uMaybeReading).getUses())
-        workingSet.push_back(&use);
+      for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
+        for (OpOperand &use : opResult.getUses())
+          workingSet.push_back(&use);
     if (bufferizesToMemoryRead(*uMaybeReading))
       return true;
   }
@@ -266,9 +268,10 @@ FailureOr<Value> BufferizationState::getBuffer(
       }))
     return resultBuffer;
   // Do not copy if the copied data is never read.
-  OpResult aliasingOpResult = getAliasingOpResult(opOperand);
-  if (aliasingOpResult && !bufferizesToMemoryRead(opOperand) &&
-      !isValueRead(aliasingOpResult))
+  SmallVector<OpResult> aliasingOpResults = getAliasingOpResult(opOperand);
+  if (!aliasingOpResults.empty() && !bufferizesToMemoryRead(opOperand) &&
+      llvm::none_of(aliasingOpResults,
+                    [&](OpResult opResult) { return isValueRead(opResult); }))
     return resultBuffer;
   // Do not copy if this op does not read the data, but writes it.
   if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index d03a287af5a0e..78e3ac8aba7c3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -140,7 +140,7 @@ bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const {
 void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand,
                                               BufferizationState &state) {
   markInPlace(operand);
-  if (OpResult result = state.getAliasingOpResult(operand))
+  for (OpResult result : state.getAliasingOpResult(operand))
     aliasInfo.unionSets(result, operand.get());
 }
 
@@ -196,8 +196,8 @@ AnalysisBufferizationState::AnalysisBufferizationState(
     for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
       if (opOperand.get().getType().isa<TensorType>())
         if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) {
-          if (OpResult opResult =
-                  bufferizableOp.getAliasingOpResult(opOperand, *this))
+          for (OpResult opResult :
+               bufferizableOp.getAliasingOpResult(opOperand, *this))
             aliasInfo.unionAliasSets(opOperand.get(), opResult);
           aliasInfo.markInPlace(opOperand);
         }
@@ -404,7 +404,9 @@ static bool hasReadAfterWriteInterference(
 
         // No conflict if the conflicting write and the last write are the same
         // use.
-        if (state.getAliasingOpResult(*uConflictingWrite) == lastWrite)
+        SmallVector<OpResult> aliasingOpResult =
+            state.getAliasingOpResult(*uConflictingWrite);
+        if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == lastWrite)
           continue;
 
         // All requirements are met. Conflict found!
@@ -477,7 +479,7 @@ static bool wouldCreateReadAfterWriteInterference(
   DenseSet<OpOperand *> usesRead, usesWrite;
   getAliasingReads(usesRead, operand.get());
   getAliasingInplaceWrites(usesWrite, operand.get());
-  if (OpResult result = state.getAliasingOpResult(operand)) {
+  for (OpResult result : state.getAliasingOpResult(operand)) {
     getAliasingReads(usesRead, result);
     getAliasingInplaceWrites(usesWrite, result);
   }
@@ -506,7 +508,7 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand,
   bool hasWrite = aliasesInPlaceWrite(opOperand.get(), aliasInfo, state) ||
                   state.bufferizesToMemoryWrite(opOperand);
 
-  if (OpResult opResult = state.getAliasingOpResult(opOperand))
+  for (OpResult opResult : state.getAliasingOpResult(opOperand))
     hasWrite |= aliasesInPlaceWrite(opResult, aliasInfo, state);
 
   return hasWrite;

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 984b42b59c7f9..7a6a04fb03873 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -168,8 +168,7 @@ struct LinalgOpInterface
     // Operand is written to if it has an aliasing OpResult. For more details,
     // see `computeAliasingPairs`.
     auto bufferizableOp = cast<BufferizableOpInterface>(op);
-    return static_cast<bool>(
-        bufferizableOp.getAliasingOpResult(opOperand, state));
+    return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
   }
 
   SmallVector<OpOperand *>
@@ -185,13 +184,16 @@ struct LinalgOpInterface
     return {};
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                               const BufferizationState &state) const {
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
     auto genericOp = cast<linalg::LinalgOp>(op);
 
     // Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`.
     DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
-    return pairs[&opOperand];
+    if (!pairs.count(&opOperand))
+      return {};
+    return {pairs[&opOperand]};
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
@@ -252,16 +254,19 @@ struct TiledLoopOpInterface
 
     // Only operands with an aliasing OpResult (i.e., output operands) bufferize
     // to a memory write.
-    return static_cast<bool>(
-        bufferizableOp.getAliasingOpResult(opOperand, state));
+    return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                               const BufferizationState &state) const {
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
     auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
 
     // Output operands are tied to their corresponding OpResults.
-    return tiledLoopOp.getTiedOpResult(opOperand);
+    OpResult opResult = tiledLoopOp.getTiedOpResult(opOperand);
+    if (!opResult)
+      return {};
+    return {opResult};
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
@@ -397,9 +402,10 @@ struct YieldOpInterface
     return false;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                               const BufferizationState &state) const {
-    return OpResult();
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
+    return {};
   }
 
   bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 8f203a707beea..4abfec9c8fbdb 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -723,25 +723,24 @@ struct CallOpInterface
         funcOp.getArgument(opOperand.getOperandNumber()));
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                               const BufferizationState &state) const {
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
     CallOp callOp = cast<CallOp>(op);
     FuncOp funcOp = getCalledFunction(callOp);
     assert(funcOp && "expected CallOp to a FuncOp");
     const ModuleBufferizationState &moduleState =
         getModuleBufferizationState(state);
 
+    SmallVector<OpResult> result;
     for (int64_t resultIdx = 0; resultIdx < callOp->getNumResults();
          ++resultIdx)
       if (Optional<int64_t> maybeArgNumber =
               getEquivalentFuncArgIdx(funcOp, moduleState, resultIdx))
         if (*maybeArgNumber == opOperand.getOperandNumber())
-          return callOp->getOpResult(resultIdx);
+          result.push_back(callOp->getOpResult(resultIdx));
 
-    // Note: Returning a non-equivalent tensor from a FuncOp is currently not
-    // supported an will fail bufferization. (Even if allow-return-memref, it
-    // will fail when the function is called.)
-    return OpResult();
+    return result;
   }
 
   SmallVector<OpOperand *>
@@ -916,9 +915,10 @@ struct ReturnOpInterface
     return false;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                               const BufferizationState &state) const {
-    return OpResult();
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
+    return {};
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,

diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index cc4147fdc2691..83a70e8dcf3af 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -278,12 +278,13 @@ struct ForOpInterface
     return true;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                               const BufferizationState &state) const {
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
     auto forOp = cast<scf::ForOp>(op);
     if (!opOperand.get().getType().isa<RankedTensorType>())
-      return OpResult();
-    return forOp.getResultForOpOperand(opOperand);
+      return {};
+    return {forOp.getResultForOpOperand(opOperand)};
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
@@ -401,13 +402,14 @@ struct YieldOpInterface
     return false;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                               const BufferizationState &state) const {
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
     if (isa<scf::IfOp>(op->getParentOp()))
-      return op->getParentOp()->getResult(opOperand.getOperandNumber());
+      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
     if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
-      return op->getParentOp()->getResult(opOperand.getOperandNumber());
-    return OpResult();
+      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
+    return {};
   }
 
   bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index f3c9fb5aeb48f..91e10916124ab 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -35,9 +35,10 @@ struct CastOpInterface
     return false;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                               const BufferizationState &state) const {
-    return op->getResult(0);
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
+    return {op->getResult(0)};
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
@@ -93,9 +94,10 @@ struct DimOpInterface
     return false;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                               const BufferizationState &state) const {
-    return OpResult();
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
+    return {};
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -121,11 +123,12 @@ struct ExtractSliceOpInterface
     return false;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                               const BufferizationState &state) const {
-    return &opOperand == &op->getOpOperand(0) /*source*/
-               ? op->getResult(0)
-               : OpResult();
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
+    if (&opOperand == &op->getOpOperand(0) /*source*/)
+      return {op->getOpResult(0)};
+    return {};
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
@@ -207,9 +210,10 @@ struct ExtractOpInterface
     return false;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                               const BufferizationState &state) const {
-    return OpResult();
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
+    return {};
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -371,11 +375,12 @@ struct InsertOpInterface
     return true;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                               const BufferizationState &state) const {
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
     assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
            "expected dest OpOperand");
-    return op->getOpResult(0);
+    return {op->getOpResult(0)};
   }
 
   SmallVector<OpOperand *>
@@ -451,11 +456,12 @@ struct InsertSliceOpInterface
     return &opOperand == &op->getOpOperand(1) /*dest*/;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                               const BufferizationState &state) const {
-    return &opOperand == &op->getOpOperand(1) /*dest*/
-               ? op->getResult(0)
-               : OpResult();
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
+    if (&opOperand == &op->getOpOperand(1) /*dest*/)
+      return {op->getResult(0)};
+    return {};
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
@@ -606,9 +612,10 @@ struct RankOpInterface
     return false;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                               const BufferizationState &state) const {
-    return OpResult();
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
+    return {};
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,

diff  --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 6252736001468..eecb7bc42eaa7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -40,9 +40,10 @@ struct TransferReadOpInterface
     return false;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                               const BufferizationState &state) const {
-    return OpResult();
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
+    return {};
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -81,11 +82,12 @@ struct TransferWriteOpInterface
     return true;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                               const BufferizationState &state) const {
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
     assert(opOperand.get().getType().isa<TensorType>() &&
            "only tensor types expected");
-    return op->getOpResult(0);
+    return {op->getOpResult(0)};
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,


        


More information about the Mlir-commits mailing list