[Mlir-commits] [mlir] 5d50f51 - [mlir][bufferization][NFC] Add error handling to getBuffer

Matthias Springer llvmlistbot at llvm.org
Mon Jun 27 04:51:37 PDT 2022


Author: Matthias Springer
Date: 2022-06-27T13:48:01+02:00
New Revision: 5d50f51c970f8bc7cb76c785f81cb13bab94d14e

URL: https://github.com/llvm/llvm-project/commit/5d50f51c970f8bc7cb76c785f81cb13bab94d14e
DIFF: https://github.com/llvm/llvm-project/commit/5d50f51c970f8bc7cb76c785f81cb13bab94d14e.diff

LOG: [mlir][bufferization][NFC] Add error handling to getBuffer

This is in preparation for adding memory space support.

Differential Revision: https://reviews.llvm.org/D128277

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
    mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index f852e7bcfe961..b609a7fd78fb6 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -479,14 +479,15 @@ Value allocateTensorForShapedValue(OpBuilder &b, Location loc,
 /// Lookup the buffer for the given value. If the value was not bufferized
 /// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp,
 /// from which the memref operand is returned.
-Value getBuffer(RewriterBase &rewriter, Value value,
-                const BufferizationOptions &options);
+FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
+                           const BufferizationOptions &options);
 
 /// Return the buffer type for a given Value (tensor) after bufferization.
 ///
 /// Note: Op implementations should preferrably call `getBuffer()->getType()`.
 /// This function should only be used if `getBuffer` cannot be used.
-BaseMemRefType getBufferType(Value value, const BufferizationOptions &options);
+FailureOr<BaseMemRefType> getBufferType(Value value,
+                                        const BufferizationOptions &options);
 
 /// Replace an op with replacement values. The op is deleted. Tensor OpResults
 /// must be replaced with memref values.

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 4c56a8196d455..8f739d758a91d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -343,7 +343,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           Return the bufferized type of the given tensor block argument. The
           block argument is guaranteed to belong to a block of this op.
         }],
-        /*retType=*/"BaseMemRefType",
+        /*retType=*/"FailureOr<BaseMemRefType>",
         /*methodName=*/"getBufferType",
         /*args=*/(ins "BlockArgument":$bbArg,
                       "const BufferizationOptions &":$options),

diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
index bd4b9d7d4a6be..24657bf3558bf 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -84,8 +84,10 @@ struct IndexCastOpInterface
     auto castOp = cast<arith::IndexCastOp>(op);
     auto resultTensorType = castOp.getType().cast<TensorType>();
 
-    Value source = getBuffer(rewriter, castOp.getIn(), options);
-    auto sourceType = source.getType().cast<BaseMemRefType>();
+    FailureOr<Value> source = getBuffer(rewriter, castOp.getIn(), options);
+    if (failed(source))
+      return failure();
+    auto sourceType = source->getType().cast<BaseMemRefType>();
 
     // Result type should have same layout and address space as the source type.
     BaseMemRefType resultType;
@@ -100,7 +102,7 @@ struct IndexCastOpInterface
     }
 
     replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
-                                                     source);
+                                                     *source);
     return success();
   }
 };
@@ -140,8 +142,14 @@ struct SelectOpInterface
     // instead of its OpOperands. In the worst case, 2 copies are inserted at
     // the moment (one for each tensor). When copying the op result, only one
     // copy would be needed.
-    Value trueBuffer = getBuffer(rewriter, selectOp.getTrueValue(), options);
-    Value falseBuffer = getBuffer(rewriter, selectOp.getFalseValue(), options);
+    FailureOr<Value> maybeTrueBuffer =
+        getBuffer(rewriter, selectOp.getTrueValue(), options);
+    FailureOr<Value> maybeFalseBuffer =
+        getBuffer(rewriter, selectOp.getFalseValue(), options);
+    if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer))
+      return failure();
+    Value trueBuffer = *maybeTrueBuffer;
+    Value falseBuffer = *maybeFalseBuffer;
 
     // The "true" and the "false" operands must have the same type. If the
     // buffers have different types, they differ only in their layout map. Cast

diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 8a01acaf9374a..7e5ccd031abeb 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -480,8 +480,8 @@ static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
 #endif
 }
 
-Value bufferization::getBuffer(RewriterBase &rewriter, Value value,
-                               const BufferizationOptions &options) {
+FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
+                                          const BufferizationOptions &options) {
 #ifndef NDEBUG
   auto tensorType = value.getType().dyn_cast<TensorType>();
   assert(tensorType && "unexpected non-tensor type");
@@ -494,14 +494,17 @@ Value bufferization::getBuffer(RewriterBase &rewriter, Value value,
   // Insert to_memref op.
   OpBuilder::InsertionGuard g(rewriter);
   setInsertionPointAfter(rewriter, value);
-  Type memrefType = getBufferType(value, options);
-  ensureToMemrefOpIsValid(value, memrefType);
-  return rewriter.create<bufferization::ToMemrefOp>(value.getLoc(), memrefType,
-                                                    value);
+  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
+  if (failed(memrefType))
+    return failure();
+  ensureToMemrefOpIsValid(value, *memrefType);
+  return rewriter
+      .create<bufferization::ToMemrefOp>(value.getLoc(), *memrefType, value)
+      .getResult();
 }
 
 /// Return the buffer type for a given Value (tensor) after bufferization.
-BaseMemRefType
+FailureOr<BaseMemRefType>
 bufferization::getBufferType(Value value, const BufferizationOptions &options) {
   auto tensorType = value.getType().dyn_cast<TensorType>();
   assert(tensorType && "unexpected non-tensor type");

diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 158f6edd1fa37..188c08b674c0e 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -165,8 +165,12 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
 
   // Get "copy" buffer.
   Value copyBuffer;
-  if (getCopy())
-    copyBuffer = getBuffer(rewriter, getCopy(), options);
+  if (getCopy()) {
+    FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
+    if (failed(maybeCopyBuffer))
+      return failure();
+    copyBuffer = *maybeCopyBuffer;
+  }
 
   // Compute memory space of this allocation.
   unsigned memorySpace;

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 3689522dd065e..b5d5cf3c52c23 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -305,8 +305,13 @@ struct CallOpInterface
 
       // Retrieve buffers for tensor operands.
       Value buffer = newOperands[idx];
-      if (!buffer)
-        buffer = getBuffer(rewriter, opOperand.get(), options);
+      if (!buffer) {
+        FailureOr<Value> maybeBuffer =
+            getBuffer(rewriter, opOperand.get(), options);
+        if (failed(maybeBuffer))
+          return failure();
+        buffer = *maybeBuffer;
+      }
 
       // Caller / callee type mismatch is handled with a CastOp.
       auto memRefType = funcType.getInput(idx);

diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index cc27b4403d898..a904d7c2e1d6d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -44,15 +44,21 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
       newInputBuffers.push_back(opOperand->get());
       continue;
     }
-    newInputBuffers.push_back(getBuffer(rewriter, opOperand->get(), options));
+    FailureOr<Value> buffer = getBuffer(rewriter, opOperand->get(), options);
+    if (failed(buffer))
+      return failure();
+    newInputBuffers.push_back(*buffer);
   }
 
   // New output operands for the cloned op.
   SmallVector<Value> newOutputBuffers;
   for (OpResult opResult : op->getOpResults()) {
     OpOperand *opOperand = op.getOutputOperand(opResult.getResultNumber());
-    Value resultBuffer = getBuffer(rewriter, opOperand->get(), options);
-    newOutputBuffers.push_back(resultBuffer);
+    FailureOr<Value> resultBuffer =
+        getBuffer(rewriter, opOperand->get(), options);
+    if (failed(resultBuffer))
+      return failure();
+    newOutputBuffers.push_back(*resultBuffer);
   }
 
   // Merge input/output operands.

diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index d55e2784c1994..16ef7b5cad130 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -281,14 +281,17 @@ static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
 
 /// Helper function for loop bufferization. Return the bufferized values of the
 /// given OpOperands. If an operand is not a tensor, return the original value.
-static SmallVector<Value> getBuffers(RewriterBase &rewriter,
-                                     MutableArrayRef<OpOperand> operands,
-                                     const BufferizationOptions &options) {
+static FailureOr<SmallVector<Value>>
+getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
+           const BufferizationOptions &options) {
   SmallVector<Value> result;
   for (OpOperand &opOperand : operands) {
     if (opOperand.get().getType().isa<TensorType>()) {
-      Value resultBuffer = getBuffer(rewriter, opOperand.get(), options);
-      result.push_back(resultBuffer);
+      FailureOr<Value> resultBuffer =
+          getBuffer(rewriter, opOperand.get(), options);
+      if (failed(resultBuffer))
+        return failure();
+      result.push_back(*resultBuffer);
     } else {
       result.push_back(opOperand.get());
     }
@@ -298,36 +301,46 @@ static SmallVector<Value> getBuffers(RewriterBase &rewriter,
 
 /// Helper function for loop bufferization. Compute the buffer that should be
 /// yielded from a loop block (loop body or loop condition).
-static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor,
-                              BaseMemRefType type,
-                              const BufferizationOptions &options) {
+static FailureOr<Value> getYieldedBuffer(RewriterBase &rewriter, Value tensor,
+                                         BaseMemRefType type,
+                                         const BufferizationOptions &options) {
   assert(tensor.getType().isa<TensorType>() && "expected tensor");
   ensureToMemrefOpIsValid(tensor, type);
-  Value yieldedVal = getBuffer(rewriter, tensor, options);
-  return castBuffer(rewriter, yieldedVal, type);
+  FailureOr<Value> yieldedVal = getBuffer(rewriter, tensor, options);
+  if (failed(yieldedVal))
+    return failure();
+  return castBuffer(rewriter, *yieldedVal, type);
 }
 
 /// Helper function for loop bufferization. Given a range of values, apply
 /// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
 /// value in the result vector.
-static SmallVector<Value>
+static FailureOr<SmallVector<Value>>
 convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
-                    llvm::function_ref<Value(Value, int64_t)> func) {
+                    llvm::function_ref<FailureOr<Value>(Value, int64_t)> func) {
   SmallVector<Value> result;
   for (const auto &it : llvm::enumerate(values)) {
     size_t idx = it.index();
     Value val = it.value();
-    result.push_back(tensorIndices.contains(idx) ? func(val, idx) : val);
+    if (tensorIndices.contains(idx)) {
+      FailureOr<Value> maybeVal = func(val, idx);
+      if (failed(maybeVal))
+        return failure();
+      result.push_back(*maybeVal);
+    } else {
+      result.push_back(val);
+    }
   }
   return result;
 }
 
 /// Helper function for loop bufferization. Given a list of pre-bufferization
 /// yielded values, compute the list of bufferized yielded values.
-SmallVector<Value> getYieldedValues(RewriterBase &rewriter, ValueRange values,
-                                    TypeRange bufferizedTypes,
-                                    const DenseSet<int64_t> &tensorIndices,
-                                    const BufferizationOptions &options) {
+FailureOr<SmallVector<Value>>
+getYieldedValues(RewriterBase &rewriter, ValueRange values,
+                 TypeRange bufferizedTypes,
+                 const DenseSet<int64_t> &tensorIndices,
+                 const BufferizationOptions &options) {
   return convertTensorValues(
       values, tensorIndices, [&](Value val, int64_t index) {
         return getYieldedBuffer(rewriter, val,
@@ -342,10 +355,19 @@ SmallVector<Value> getYieldedValues(RewriterBase &rewriter, ValueRange values,
 SmallVector<Value>
 getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                      const DenseSet<int64_t> &tensorIndices) {
-  return convertTensorValues(
-      bbArgs, tensorIndices, [&](Value val, int64_t index) {
-        return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
-      });
+  SmallVector<Value> result;
+  for (const auto &it : llvm::enumerate(bbArgs)) {
+    size_t idx = it.index();
+    Value val = it.value();
+    if (tensorIndices.contains(idx)) {
+      result.push_back(
+          rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val)
+              .getResult());
+    } else {
+      result.push_back(val);
+    }
+  }
+  return result;
 }
 
 /// Bufferization of scf.for. Replace with a new scf.for that operates on
@@ -445,8 +467,9 @@ struct ForOpInterface
     return success();
   }
 
-  BaseMemRefType getBufferType(Operation *op, BlockArgument bbArg,
-                               const BufferizationOptions &options) const {
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, BlockArgument bbArg,
+                const BufferizationOptions &options) const {
     auto forOp = cast<scf::ForOp>(op);
     return bufferization::getBufferType(
         forOp.getOpOperandForRegionIterArg(bbArg).get(), options);
@@ -462,8 +485,11 @@ struct ForOpInterface
     DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
 
     // The new memref init_args of the loop.
-    SmallVector<Value> initArgs =
+    FailureOr<SmallVector<Value>> maybeInitArgs =
         getBuffers(rewriter, forOp.getIterOpOperands(), options);
+    if (failed(maybeInitArgs))
+      return failure();
+    SmallVector<Value> initArgs = *maybeInitArgs;
 
     // Construct a new scf.for op with memref instead of tensor values.
     auto newForOp = rewriter.create<scf::ForOp>(
@@ -689,13 +715,17 @@ struct WhileOpInterface
         getTensorIndices(whileOp.getAfterArguments());
 
     // The new memref init_args of the loop.
-    SmallVector<Value> initArgs =
+    FailureOr<SmallVector<Value>> maybeInitArgs =
         getBuffers(rewriter, whileOp->getOpOperands(), options);
+    if (failed(maybeInitArgs))
+      return failure();
+    SmallVector<Value> initArgs = *maybeInitArgs;
 
     // The result types of a WhileOp are the same as the "after" bbArg types.
     SmallVector<Type> argsTypesAfter = llvm::to_vector(
         llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
-          return bufferization::getBufferType(bbArg, options).cast<Type>();
+          // TODO: error handling
+          return bufferization::getBufferType(bbArg, options)->cast<Type>();
         }));
 
     // Construct a new scf.while op with memref instead of tensor values.
@@ -727,10 +757,12 @@ struct WhileOpInterface
     // Only equivalent buffers or new buffer allocations may be yielded to the
     // "after" region.
     // TODO: This could be relaxed for better bufferization results.
-    SmallVector<Value> newConditionArgs =
+    FailureOr<SmallVector<Value>> newConditionArgs =
         getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
                          indicesAfter, options);
-    newConditionOp.getArgsMutable().assign(newConditionArgs);
+    if (failed(newConditionArgs))
+      return failure();
+    newConditionOp.getArgsMutable().assign(*newConditionArgs);
 
     // Set up new iter_args and move the loop body block to the new op.
     // The old block uses tensors, so wrap the (memref) bbArgs of the new block
@@ -746,10 +778,12 @@ struct WhileOpInterface
     // Only equivalent buffers or new buffer allocations may be yielded to the
     // "before" region.
     // TODO: This could be relaxed for better bufferization results.
-    SmallVector<Value> newYieldValues =
+    FailureOr<SmallVector<Value>> newYieldValues =
         getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
                          indicesBefore, options);
-    newYieldOp.getResultsMutable().assign(newYieldValues);
+    if (failed(newYieldValues))
+      return failure();
+    newYieldOp.getResultsMutable().assign(*newYieldValues);
 
     // Replace loop results.
     replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
@@ -849,13 +883,18 @@ struct YieldOpInterface
     for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
       Value value = it.value();
       if (value.getType().isa<TensorType>()) {
-        Value buffer = getBuffer(rewriter, value, options);
+        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
+        if (failed(maybeBuffer))
+          return failure();
+        Value buffer = *maybeBuffer;
         if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
-          BaseMemRefType resultType =
+          FailureOr<BaseMemRefType> resultType =
               cast<BufferizableOpInterface>(forOp.getOperation())
                   .getBufferType(forOp.getRegionIterArgs()[it.index()],
                                  options);
-          buffer = castBuffer(rewriter, buffer, resultType);
+          if (failed(resultType))
+            return failure();
+          buffer = castBuffer(rewriter, buffer, *resultType);
         }
         newResults.push_back(buffer);
       } else {
@@ -1078,16 +1117,22 @@ struct ParallelInsertSliceOpInterface
     // If the op bufferizes out-of-place, allocate the copy before the
     // ForeachThreadOp.
     rewriter.setInsertionPoint(foreachThreadOp);
-    Value destBuffer = getBuffer(rewriter, insertOp.getDest(), options);
+    FailureOr<Value> destBuffer =
+        getBuffer(rewriter, insertOp.getDest(), options);
+    if (failed(destBuffer))
+      return failure();
 
     // Bufferize the ParallelInsertSliceOp outside of the PerformConcurrentlyOp.
     rewriter.setInsertionPoint(performConcurrentlyOp);
-    Value srcBuffer = getBuffer(rewriter, insertOp.getSource(), options);
+    FailureOr<Value> srcBuffer =
+        getBuffer(rewriter, insertOp.getSource(), options);
+    if (failed(srcBuffer))
+      return failure();
     Value subview = rewriter.create<memref::SubViewOp>(
-        insertOp.getLoc(), destBuffer, insertOp.getMixedOffsets(),
+        insertOp.getLoc(), *destBuffer, insertOp.getMixedOffsets(),
         insertOp.getMixedSizes(), insertOp.getMixedStrides());
     // This memcpy will fold away if everything bufferizes in-place.
-    if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), srcBuffer,
+    if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), *srcBuffer,
                                     subview)))
       return failure();
     rewriter.eraseOp(op);
@@ -1095,7 +1140,7 @@ struct ParallelInsertSliceOpInterface
     // Replace all uses of ForeachThreadOp (just the corresponding result).
     rewriter.setInsertionPointAfter(foreachThreadOp);
     Value toTensorOp =
-        rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), destBuffer);
+        rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), *destBuffer);
     unsigned resultNum = 0;
     for (Operation &nextOp : performConcurrentlyOp.yieldingOps()) {
       if (&nextOp == op)

diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index 9a4f8e187a8b7..68580b8680afc 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -130,10 +130,16 @@ struct AssumingYieldOpInterface
                           const BufferizationOptions &options) const {
     auto yieldOp = cast<shape::AssumingYieldOp>(op);
     SmallVector<Value> newResults;
-    for (Value value : yieldOp.operands())
-      newResults.push_back(value.getType().isa<TensorType>()
-                               ? getBuffer(rewriter, value, options)
-                               : value);
+    for (Value value : yieldOp.operands()) {
+      if (value.getType().isa<TensorType>()) {
+        FailureOr<Value> buffer = getBuffer(rewriter, value, options);
+        if (failed(buffer))
+          return failure();
+        newResults.push_back(*buffer);
+      } else {
+        newResults.push_back(value);
+      }
+    }
     replaceOpWithNewBufferizedOp<shape::AssumingYieldOp>(rewriter, op,
                                                          newResults);
     return success();

diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 7f8b6a35491d5..e7e31dcd42f54 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -52,8 +52,11 @@ struct CastOpInterface
     auto castOp = cast<tensor::CastOp>(op);
 
     // The result buffer still has the old (pre-cast) type.
-    Value resultBuffer = getBuffer(rewriter, castOp.getSource(), options);
-    auto sourceMemRefType = resultBuffer.getType().cast<BaseMemRefType>();
+    FailureOr<Value> resultBuffer =
+        getBuffer(rewriter, castOp.getSource(), options);
+    if (failed(resultBuffer))
+      return failure();
+    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
     TensorType resultTensorType =
         castOp.getResult().getType().cast<TensorType>();
     MemRefLayoutAttrInterface layout;
@@ -68,11 +71,11 @@ struct CastOpInterface
                       sourceMemRefType.getMemorySpaceAsInt());
 
     // Replace the op with a memref.cast.
-    assert(memref::CastOp::areCastCompatible(resultBuffer.getType(),
+    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                              resultMemRefType) &&
            "CallOp::bufferize: cast incompatible");
     replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
-                                                 resultBuffer);
+                                                 *resultBuffer);
 
     return success();
   }
@@ -108,7 +111,11 @@ struct CollapseShapeOpInterface
                           const BufferizationOptions &options) const {
     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
     RankedTensorType tensorResultType = collapseShapeOp.getResultType();
-    Value buffer = getBuffer(rewriter, collapseShapeOp.getSrc(), options);
+    FailureOr<Value> maybeBuffer =
+        getBuffer(rewriter, collapseShapeOp.getSrc(), options);
+    if (failed(maybeBuffer))
+      return failure();
+    Value buffer = *maybeBuffer;
     auto bufferType = buffer.getType().cast<MemRefType>();
 
     if (tensorResultType.getRank() == 0) {
@@ -187,9 +194,11 @@ struct DimOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto dimOp = cast<tensor::DimOp>(op);
-    auto v = getBuffer(rewriter, dimOp.getSource(), options);
-    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v,
-                                                dimOp.getIndex());
+    FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
+    if (failed(v))
+      return failure();
+    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
+                                                dimOp.index());
     return success();
   }
 };
@@ -224,12 +233,15 @@ struct ExpandShapeOpInterface
                           const BufferizationOptions &options) const {
     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
     auto tensorResultType = expandShapeOp.getResultType();
-    auto buffer = getBuffer(rewriter, expandShapeOp.getSrc(), options);
+    FailureOr<Value> buffer =
+        getBuffer(rewriter, expandShapeOp.getSrc(), options);
+    if (failed(buffer))
+      return failure();
 
     // Memref result type is inferred by the builder based on reassociation
     // indices and result shape.
     replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
-        rewriter, op, tensorResultType.getShape(), buffer,
+        rewriter, op, tensorResultType.getShape(), *buffer,
         expandShapeOp.getReassociationIndices());
     return success();
   }
@@ -268,8 +280,11 @@ struct ExtractSliceOpInterface
 
     // Even if this op was decided to bufferize out-of-place, do not insert the
     // buffer copy yet. This is done later in this function.
-    auto srcMemref = getBuffer(rewriter, extractSliceOp.getSource(), options);
-    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
+    FailureOr<Value> srcMemref =
+        getBuffer(rewriter, extractSliceOp.getSource(), options);
+    if (failed(srcMemref))
+      return failure();
+    auto srcMemrefType = srcMemref->getType().cast<MemRefType>();
     auto dstTensorType =
         extractSliceOp.getResult().getType().cast<RankedTensorType>();
 
@@ -279,7 +294,7 @@ struct ExtractSliceOpInterface
     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
     OffsetSizeAndStrideOpInterface::expandToRank(
-        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
+        *srcMemref, mixedOffsets, mixedSizes, mixedStrides,
         [&](Value target, int64_t dim) -> OpFoldResult {
           auto shapedType = target.getType().cast<ShapedType>();
           if (shapedType.isDynamicDim(dim))
@@ -292,7 +307,7 @@ struct ExtractSliceOpInterface
                                  mixedOffsets, mixedSizes, mixedStrides)
                                  .cast<MemRefType>();
     Value subView = rewriter.create<memref::SubViewOp>(
-        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
+        loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
         mixedStrides);
 
     replaceOpWithBufferizedValues(rewriter, op, subView);
@@ -322,9 +337,12 @@ struct ExtractOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto extractOp = cast<tensor::ExtractOp>(op);
-    Value srcMemref = getBuffer(rewriter, extractOp.getTensor(), options);
-    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
-                                                 extractOp.getIndices());
+    FailureOr<Value> srcMemref =
+        getBuffer(rewriter, extractOp.getTensor(), options);
+    if (failed(srcMemref))
+      return failure();
+    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
+                                                 extractOp.indices());
     return success();
   }
 };
@@ -497,10 +515,13 @@ struct InsertOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto insertOp = cast<tensor::InsertOp>(op);
-    Value destMemref = getBuffer(rewriter, insertOp.getDest(), options);
+    FailureOr<Value> destMemref =
+        getBuffer(rewriter, insertOp.getDest(), options);
+    if (failed(destMemref))
+      return failure();
     rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
-                                     destMemref, insertOp.getIndices());
-    replaceOpWithBufferizedValues(rewriter, op, destMemref);
+                                     *destMemref, insertOp.getIndices());
+    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
     return success();
   }
 
@@ -655,7 +676,10 @@ struct InsertSliceOpInterface
     // TODO: be very loud about it or even consider failing the pass.
     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
     Location loc = insertSliceOp.getLoc();
-    Value dstMemref = getBuffer(rewriter, insertSliceOp.getDest(), options);
+    FailureOr<Value> dstMemref =
+        getBuffer(rewriter, insertSliceOp.getDest(), options);
+    if (failed(dstMemref))
+      return failure();
 
     // Expand offsets, sizes and strides to the full rank to handle the
     // rank-reducing case.
@@ -663,7 +687,7 @@ struct InsertSliceOpInterface
     SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
     SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
     OffsetSizeAndStrideOpInterface::expandToRank(
-        dstMemref, mixedOffsets, mixedSizes, mixedStrides,
+        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
         [&](Value target, int64_t dim) -> OpFoldResult {
           auto shapedType = target.getType().cast<ShapedType>();
           if (shapedType.isDynamicDim(dim))
@@ -671,23 +695,26 @@ struct InsertSliceOpInterface
           return rewriter.getIndexAttr(shapedType.getDimSize(dim));
         });
     // Take a subview of the dst.
-    auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
+    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
     auto subviewMemRefType =
         memref::SubViewOp::inferRankReducedResultType(
             insertSliceOp.getSourceType().getRank(), dstMemrefType,
             mixedOffsets, mixedSizes, mixedStrides)
             .cast<MemRefType>();
     Value subView = rewriter.create<memref::SubViewOp>(
-        loc, subviewMemRefType, dstMemref, mixedOffsets, mixedSizes,
+        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
         mixedStrides);
 
     // Copy tensor. If this tensor.insert_slice has a matching
     // tensor.extract_slice, the copy operation will eventually fold away.
-    auto srcMemref = getBuffer(rewriter, insertSliceOp.getSource(), options);
-    if (failed(options.createMemCpy(rewriter, loc, srcMemref, subView)))
+    FailureOr<Value> srcMemref =
+        getBuffer(rewriter, insertSliceOp.getSource(), options);
+    if (failed(srcMemref))
+      return failure();
+    if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
       return failure();
 
-    replaceOpWithBufferizedValues(rewriter, op, dstMemref);
+    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
     return success();
   }
 };
@@ -714,9 +741,11 @@ struct RankOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto rankOp = cast<tensor::RankOp>(op);
-    auto v = getBuffer(rewriter, rankOp.getTensor(), options);
+    FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
+    if (failed(v))
+      return failure();
     replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
-                                                 v);
+                                                 *v);
     return success();
   }
 };
@@ -750,12 +779,16 @@ struct ReshapeOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto reshapeOp = cast<tensor::ReshapeOp>(op);
-    Value srcBuffer = getBuffer(rewriter, reshapeOp.getSource(), options);
-    Value shapeBuffer = getBuffer(rewriter, reshapeOp.getShape(), options);
+    FailureOr<Value> srcBuffer =
+        getBuffer(rewriter, reshapeOp.getSource(), options);
+    FailureOr<Value> shapeBuffer =
+        getBuffer(rewriter, reshapeOp.getShape(), options);
+    if (failed(srcBuffer) || failed(shapeBuffer))
+      return failure();
     auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
     auto resultMemRefType = getMemRefType(resultTensorType, options);
     replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
-        rewriter, op, resultMemRefType, srcBuffer, shapeBuffer);
+        rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 142e09bbb3da5..77f895a929c33 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -50,9 +50,11 @@ struct TransferReadOpInterface
     auto readOp = cast<vector::TransferReadOp>(op);
     assert(readOp.getShapedType().isa<TensorType>() &&
            "only tensor types expected");
-    Value buffer = getBuffer(rewriter, readOp.getSource(), options);
+    FailureOr<Value> buffer = getBuffer(rewriter, readOp.getSource(), options);
+    if (failed(buffer))
+      return failure();
     replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
-        rewriter, readOp, readOp.getVectorType(), buffer, readOp.getIndices(),
+        rewriter, readOp, readOp.getVectorType(), *buffer, readOp.getIndices(),
         readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
         readOp.getInBoundsAttr());
     return success();
@@ -97,12 +99,15 @@ struct TransferWriteOpInterface
            "only tensor types expected");
 
     // Create a new transfer_write on buffer that doesn't have a return value.
-    Value resultBuffer = getBuffer(rewriter, writeOp.getSource(), options);
+    FailureOr<Value> resultBuffer =
+        getBuffer(rewriter, writeOp.getSource(), options);
+    if (failed(resultBuffer))
+      return failure();
     rewriter.create<vector::TransferWriteOp>(
-        writeOp.getLoc(), writeOp.getVector(), resultBuffer,
+        writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
         writeOp.getIndices(), writeOp.getPermutationMapAttr(),
         writeOp.getInBoundsAttr());
-    replaceOpWithBufferizedValues(rewriter, op, resultBuffer);
+    replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
 
     return success();
   }


        


More information about the Mlir-commits mailing list