[Mlir-commits] [mlir] d9184ab - [mlir][linalg][bufferize][NFC] Simplify buffer API of BufferizationState

Matthias Springer llvmlistbot at llvm.org
Fri Jan 7 08:16:26 PST 2022


Author: Matthias Springer
Date: 2022-01-08T01:12:18+09:00
New Revision: d9184ab1a53ae008722fc569d147b1c88acbbd9d

URL: https://github.com/llvm/llvm-project/commit/d9184ab1a53ae008722fc569d147b1c88acbbd9d
DIFF: https://github.com/llvm/llvm-project/commit/d9184ab1a53ae008722fc569d147b1c88acbbd9d.diff

LOG: [mlir][linalg][bufferize][NFC] Simplify buffer API of BufferizationState

The separate `lookupBuffer` and `getResultBuffer` functions are replaced by a single `getBuffer` function. This simplifies the `BufferizableOpInterface` API and is less confusing to users, who previously could call the wrong one of the two functions.

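Furthermore, since `getBuffer` now takes an `OpOperand &` instead of a `Value`, users can no longer accidentally misuse one of the previous two functions (e.g., by calling `lookupBuffer` on a tensor operand that was decided to bufferize out-of-place, which skips the allocation and copy), a mistake that would have resulted in missing buffer copies.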

Differential Revision: https://reviews.llvm.org/D116455

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index d5fae8925ffd..22bf5e645049 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -377,18 +377,14 @@ class BufferizationState {
   /// Creates a memcpy between two given buffers.
   void createMemCpy(OpBuilder &b, Location loc, Value from, Value to) const;
 
-  /// Lookup the memref buffer that is associated to the given tensor value.
-  /// Asserts if no buffer is associated.
-  Value lookupBuffer(RewriterBase &rewriter, Value tensor) const;
-
   /// Return `true` if the given OpResult has been decided to bufferize inplace.
   bool isInPlace(OpOperand &opOperand) const;
 
-  /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
+  /// Return the buffer (memref) for a given OpOperand (tensor). Allocate
   /// a new buffer and copy over data from the existing buffer if out-of-place
-  /// bufferization is necessary.
-  FailureOr<Value> getResultBuffer(RewriterBase &rewriter,
-                                   OpResult result) const;
+  /// bufferization was decided.
+  FailureOr<Value> getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
+                             bool forceInPlace = false) const;
 
   /// Return dialect-specific bufferization state.
   template <typename StateT>

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index e3ea2c805443..d2d726312d6a 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -347,74 +347,73 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
   });
 }
 
+static Value lookupBuffer(RewriterBase &rewriter, Value tensor) {
+  assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
+
+  // Replace "%t = to_tensor %m" with %m.
+  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
+    return toTensorOp.memref();
+
+  // Insert to_memref op.
+  OpBuilder::InsertionGuard g(rewriter);
+  setInsertionPointAfter(rewriter, tensor);
+  Type memrefType;
+  if (auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>()) {
+    memrefType = getDynamicMemRefType(rankedTensorType);
+  } else {
+    memrefType = getUnrankedMemRefType(
+        tensor.getType().cast<TensorType>().getElementType());
+  }
+  return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
+                                                    tensor);
+}
+
 /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
 /// a new buffer and copy over data from the existing buffer if out-of-place
 /// bufferization is necessary.
 FailureOr<Value>
-mlir::linalg::comprehensive_bufferize::BufferizationState::getResultBuffer(
-    RewriterBase &rewriter, OpResult result) const {
+mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
+    RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace) const {
   OpBuilder::InsertionGuard guard(rewriter);
-  Operation *op = result.getOwner();
-  SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
-  assert(!aliasingOperands.empty() && "could not get aliasing OpOperand");
-  OpOperand *opOperand = aliasingOperands.front();
-  Value operand = opOperand->get();
+  Operation *op = opOperand.getOwner();
+  Location loc = op->getLoc();
+  Value operand = opOperand.get();
   Value operandBuffer = lookupBuffer(rewriter, operand);
-  // Make sure that all OpOperands are the same buffer. If this is not the case,
-  // we would have to materialize a memref value.
-  // TODO: Should be looking for checking for "equivalent buffers" instead of
-  // operator== here, but equivalent buffers for scf.if yield values are not
-  // set up yet.
-  if (aliasingOperands.size() > 1 &&
-      !llvm::all_of(aliasingOperands, [&](OpOperand *o) {
-        return lookupBuffer(rewriter, o->get()) == operandBuffer;
-      }))
-    return FailureOr<Value>(op->emitError("result buffer is ambiguous"));
-
-  // If bufferizing out-of-place, allocate a new buffer.
-  if (!aliasInfo.isInPlace(*opOperand)) {
-    // Ops with multiple aliasing operands can currently not bufferize
-    // out-of-place.
-    assert(
-        aliasingOperands.size() == 1 &&
-        "ops with multiple aliasing OpOperands cannot bufferize out-of-place");
-    Location loc = op->getLoc();
-    // Move insertion point right after `operandBuffer`. That is where the
-    // allocation should be inserted (in the absence of allocation hoisting).
-    setInsertionPointAfter(rewriter, operandBuffer);
-    // Allocate the result buffer.
-    FailureOr<Value> resultBuffer =
-        createAlloc(rewriter, loc, operandBuffer, options.createDeallocs);
-    if (failed(resultBuffer))
-      return failure();
-    bool skipCopy = false;
-    // Do not copy if the last preceding write of `operand` is an op that does
-    // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
-    // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
-    // use-def chain, it returns that value, regardless of whether it is a
-    // memory write or not.
-    Value lastWrite = findLastPrecedingWrite(operand);
-    if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
-      if (!bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(), *this))
-        skipCopy = true;
-    // Do not copy if the copied data is never read. (Neither by this op nor by
-    // any following op.)
-    if (!bufferizesToMemoryRead(*opOperand) && !isValueRead(result))
-      skipCopy = true;
-    // Do not copy if this op does not read the data, but writes it.
-    if (bufferizesToMemoryWrite(*opOperand) &&
-        !bufferizesToMemoryRead(*opOperand))
-      skipCopy = true;
-    if (!skipCopy) {
-      // The copy happens right before the op that is bufferized.
-      rewriter.setInsertionPoint(op);
-      createMemCpy(rewriter, loc, operandBuffer, *resultBuffer);
-    }
+
+  if (forceInPlace || aliasInfo.isInPlace(opOperand))
+    return operandBuffer;
+
+  // Bufferizing out-of-place: Allocate a new buffer.
+  // Move insertion point right after `operandBuffer`. That is where the
+  // allocation should be inserted (in the absence of allocation hoisting).
+  setInsertionPointAfter(rewriter, operandBuffer);
+  // Allocate the result buffer.
+  FailureOr<Value> resultBuffer =
+      createAlloc(rewriter, loc, operandBuffer, options.createDeallocs);
+  if (failed(resultBuffer))
+    return failure();
+  // Do not copy if the last preceding write of `operand` is an op that does
+  // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
+  // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
+  // use-def chain, it returns that value, regardless of whether it is a
+  // memory write or not.
+  Value lastWrite = findLastPrecedingWrite(operand);
+  if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
+    if (!bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(), *this))
+      return resultBuffer;
+  // Do not copy if the copied data is never read.
+  OpResult aliasingOpResult = getAliasingOpResult(opOperand);
+  if (aliasingOpResult && !bufferizesToMemoryRead(opOperand) &&
+      !isValueRead(aliasingOpResult))
+    return resultBuffer;
+  // Do not copy if this op does not read the data, but writes it.
+  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
     return resultBuffer;
-  }
 
-  // Bufferizing in-place. No need to allocate a new buffer.
-  return operandBuffer;
+  // The copy happens right before the op that is bufferized.
+  rewriter.setInsertionPoint(op);
+  createMemCpy(rewriter, loc, operandBuffer, *resultBuffer);
+  return resultBuffer;
 }
 
 void mlir::linalg::comprehensive_bufferize::replaceOpWithBufferizedValues(
@@ -593,28 +592,6 @@ bool mlir::linalg::comprehensive_bufferize::isFunctionArgument(Value value) {
   return isa<FuncOp>(bbArg.getOwner()->getParentOp());
 }
 
-Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
-    RewriterBase &rewriter, Value tensor) const {
-  assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
-
-  // Replace "%t = to_tensor %m" with %m.
-  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
-    return toTensorOp.memref();
-
-  // Insert to_memref op.
-  OpBuilder::InsertionGuard g(rewriter);
-  setInsertionPointAfter(rewriter, tensor);
-  Type memrefType;
-  if (auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>()) {
-    memrefType = getDynamicMemRefType(rankedTensorType);
-  } else {
-    memrefType = getUnrankedMemRefType(
-        tensor.getType().cast<TensorType>().getElementType());
-  }
-  return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
-                                                    tensor);
-}
-
 bool mlir::linalg::comprehensive_bufferize::BufferizationState::isInPlace(
     OpOperand &opOperand) const {
   return aliasInfo.isInPlace(opOperand);

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 2bdf6e757c98..0693dee0c405 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -46,15 +46,19 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
       newInputBuffers.push_back(opOperand->get());
       continue;
     }
-    newInputBuffers.push_back(state.lookupBuffer(rewriter, opOperand->get()));
+    // Input operands are never written to.
+    newInputBuffers.push_back(
+        *state.getBuffer(rewriter, *opOperand, /*forceInPlace=*/true));
   }
 
   // New output operands for the cloned op.
   SmallVector<Value> newOutputBuffers;
-  for (OpOperand *opOperand : op.getOutputOperands()) {
-    OpResult opResult = op.getTiedOpResult(opOperand);
-    assert(opResult && "could not find correspond OpResult");
-    FailureOr<Value> resultBuffer = state.getResultBuffer(rewriter, opResult);
+  for (OpResult opResult : op->getOpResults()) {
+    SmallVector<OpOperand *> aliasingOpOperands =
+        state.getAliasingOpOperand(opResult);
+    assert(aliasingOpOperands.size() == 1 && "expected 1 OpOperand");
+    FailureOr<Value> resultBuffer =
+        state.getBuffer(rewriter, *aliasingOpOperands.front());
     if (failed(resultBuffer))
       return failure();
     newOutputBuffers.push_back(*resultBuffer);
@@ -284,24 +288,23 @@ struct TiledLoopOpInterface
 
     // Compute new inputs, outputs and results.
     SmallVector<Value> newInputs, newOutputs, newResults;
-    for (Value value : tiledLoopOp.inputs()) {
-      if (value.getType().isa<TensorType>()) {
-        newInputs.push_back(state.lookupBuffer(rewriter, value));
-      } else {
-        newInputs.push_back(value);
-      }
-    }
-    int nextResultNum = 0;
-    for (Value value : tiledLoopOp.outputs()) {
-      if (value.getType().isa<TensorType>()) {
-        FailureOr<Value> buffer = state.getResultBuffer(
-            rewriter, tiledLoopOp->getResult(nextResultNum++));
-        if (failed(buffer))
+    for (int i = tiledLoopOp.getNumControlOperands();
+         i < tiledLoopOp->getNumOperands(); ++i) {
+      OpOperand &operand = tiledLoopOp->getOpOperand(i);
+      Value rewrittenValue = operand.get();
+      if (rewrittenValue.getType().isa<TensorType>()) {
+        FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, operand);
+        if (failed(bufferOrFailure))
           return failure();
-        newOutputs.push_back(*buffer);
-        newResults.push_back(*buffer);
+        rewrittenValue = *bufferOrFailure;
+      }
+      if (i <
+          tiledLoopOp.getNumControlOperands() + tiledLoopOp.getNumInputs()) {
+        newInputs.push_back(rewrittenValue);
       } else {
-        newOutputs.push_back(value);
+        newOutputs.push_back(rewrittenValue);
+        if (operand.get().getType().isa<TensorType>())
+          newResults.push_back(rewrittenValue);
       }
     }
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index c49f45da13c5..bb9a06c25150 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -351,7 +351,7 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
 
     // Cast values at the call site if necessary.
     returnValues.push_back(
-        getNonCastedValue(state.lookupBuffer(rewriter, returnVal)));
+        getNonCastedValue(*state.getBuffer(rewriter, returnOperand)));
   }
 
   // 2. Rewrite the terminator without the inPlace bufferizable values.
@@ -659,7 +659,8 @@ struct CallOpInterface
         // Return operands that are equivalent to some bbArg, are not
         // returned.
         Value buffer =
-            state.lookupBuffer(rewriter, callOp->getOperand(*bbArgIdx));
+            *state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx),
+                             /*forceInPlace=*/true);
         replacementValues[returnValIdx] = buffer;
         newOperands[*bbArgIdx] = buffer;
         continue;
@@ -690,9 +691,9 @@ struct CallOpInterface
       // Retrieve buffers for tensor operands. Tensor operand buffers, who's
       // corresponding FuncOp bbArgs are equivalent to a returned tensor, were
       // already stored in `newOperands` during Step 1.
-      Value buffer = newOperands[idx]
-                         ? newOperands[idx]
-                         : state.lookupBuffer(rewriter, tensorOperand);
+      Value buffer = newOperands[idx] ? newOperands[idx]
+                                      : *state.getBuffer(rewriter, opOperand,
+                                                         /*forceInPlace=*/true);
 
       // Caller / callee type mistmatch is handled with a CastOp.
       auto memRefType = bufferizedFuncType.getInput(idx);

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index 62fdd5a78051..01308088bab8 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -280,19 +280,17 @@ struct ForOpInterface
     };
 
     // Construct a new scf.for op with memref instead of tensor values.
-    bool resultBufferFailure = false;
-    SmallVector<Value> initArgs =
-        convert(forOp.getInitArgs(), [&](Value val, int64_t index) {
-          FailureOr<Value> resultBuffer =
-              state.getResultBuffer(rewriter, forOp->getOpResult(index));
-          if (failed(resultBuffer)) {
-            resultBufferFailure = true;
-            return Value();
-          }
-          return *resultBuffer;
-        });
-    if (resultBufferFailure)
-      return failure();
+    SmallVector<Value> initArgs;
+    for (OpOperand &opOperand : forOp.getIterOpOperands()) {
+      if (opOperand.get().getType().isa<TensorType>()) {
+        FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
+        if (failed(resultBuffer))
+          return failure();
+        initArgs.push_back(*resultBuffer);
+      } else {
+        initArgs.push_back(opOperand.get());
+      }
+    }
     auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
         forOp.getStep(), initArgs);

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index 86df686239d6..7c9114b284b2 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -53,7 +53,7 @@ struct CastOpInterface
 
     // The result buffer still has the old (pre-cast) type.
     FailureOr<Value> resultBuffer =
-        state.getResultBuffer(rewriter, castOp->getResult(0));
+        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
     if (failed(resultBuffer))
       return failure();
     auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
@@ -106,7 +106,7 @@ struct DimOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationState &state) const {
     auto dimOp = cast<tensor::DimOp>(op);
-    Value v = state.lookupBuffer(rewriter, dimOp.source());
+    Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
     replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
     return success();
   }
@@ -143,7 +143,9 @@ struct ExtractSliceOpInterface
                           const BufferizationState &state) const {
     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
     Location loc = extractSliceOp.getLoc();
-    Value srcMemref = state.lookupBuffer(rewriter, extractSliceOp.source());
+    Value srcMemref =
+        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
+                         /*forceInPlace=*/true);
     auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
     auto dstTensorType =
         extractSliceOp.result().getType().cast<RankedTensorType>();
@@ -206,7 +208,8 @@ struct ExtractOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationState &state) const {
     auto extractOp = cast<tensor::ExtractOp>(op);
-    Value srcMemref = state.lookupBuffer(rewriter, extractOp.tensor());
+    Value srcMemref =
+        *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
     replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                  extractOp.indices());
     return success();
@@ -244,7 +247,7 @@ struct InsertOpInterface
                           const BufferizationState &state) const {
     auto insertOp = cast<tensor::InsertOp>(op);
     FailureOr<Value> destMemref =
-        state.getResultBuffer(rewriter, insertOp->getOpResult(0));
+        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
     if (failed(destMemref))
       return failure();
     rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
@@ -412,7 +415,7 @@ struct InsertSliceOpInterface
 
     // When bufferizing out-of-place, `getResultBuffer` allocates.
     FailureOr<Value> dstMemref =
-        state.getResultBuffer(rewriter, insertSliceOp->getResult(0));
+        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
     if (failed(dstMemref))
       return failure();
 
@@ -430,7 +433,8 @@ struct InsertSliceOpInterface
 
     // Copy tensor. If this tensor.insert_slice has a matching
     // tensor.extract_slice, the copy operation will eventually fold away.
-    Value srcMemref = state.lookupBuffer(rewriter, insertSliceOp.source());
+    Value srcMemref =
+        *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
     state.createMemCpy(rewriter, loc, srcMemref, subView);
 
     replaceOpWithBufferizedValues(rewriter, op, *dstMemref);

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
index 58013323cb70..0b3d8ff6d266 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -48,7 +48,8 @@ struct TransferReadOpInterface
            "only tensor types expected");
 
     // TransferReadOp always reads from the bufferized op.source().
-    Value buffer = state.lookupBuffer(rewriter, readOp.source());
+    Value buffer =
+        *state.getBuffer(rewriter, readOp->getOpOperand(0) /*source*/);
     replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
         rewriter, readOp, readOp.getVectorType(), buffer, readOp.indices(),
         readOp.permutation_map(), readOp.padding(), readOp.mask(),
@@ -99,7 +100,7 @@ struct TransferWriteOpInterface
     // Leave the previous transfer_write to dead code as it still has uses at
     // this point.
     FailureOr<Value> resultBuffer =
-        state.getResultBuffer(rewriter, op->getResult(0));
+        state.getBuffer(rewriter, op->getOpOperand(1) /*source*/);
     if (failed(resultBuffer))
       return failure();
     rewriter.create<vector::TransferWriteOp>(


        


More information about the Mlir-commits mailing list