[Mlir-commits] [mlir] 698896c - [mlir][linalg][bufferize][NFC] Change allocationFn return type to FailureOr<Value>

Matthias Springer llvmlistbot at llvm.org
Thu Jan 6 13:49:40 PST 2022


Author: Matthias Springer
Date: 2022-01-07T06:33:19+09:00
New Revision: 698896cd6c8cc5e865e1715e7c9d82295f82745b

URL: https://github.com/llvm/llvm-project/commit/698896cd6c8cc5e865e1715e7c9d82295f82745b
DIFF: https://github.com/llvm/llvm-project/commit/698896cd6c8cc5e865e1715e7c9d82295f82745b.diff

LOG: [mlir][linalg][bufferize][NFC] Change allocationFn return type to FailureOr<Value>

In addition, all functions that call `allocationFn` now return `FailureOr<Value>`. This resolves a few TODOs in the codebase.

Differential Revision: https://reviews.llvm.org/D116452

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 921353a23ea7..c18f7f9fc5e9 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -41,7 +41,7 @@ struct PostAnalysisStep;
 // TODO: Could be replaced with a "bufferization strategy" object with virtual
 // functions in the future.
 struct AllocationCallbacks {
-  using AllocationFn = std::function<Optional<Value>(
+  using AllocationFn = std::function<FailureOr<Value>(
       OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
   using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
   using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
@@ -360,15 +360,15 @@ class BufferizationState {
   Value findLastPrecedingWrite(Value value) const;
 
   /// Creates a memref allocation.
-  Optional<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
-                              ArrayRef<Value> dynShape) const;
+  FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
+                               ArrayRef<Value> dynShape) const;
 
   /// Creates a memref allocation for the given shaped value. This function may
   /// perform additional optimizations such as buffer allocation hoisting. If
   /// `createDealloc`, a deallocation op is inserted at the point where the
   /// allocation goes out of scope.
-  Value createAlloc(OpBuilder &b, Location loc, Value shapedValue,
-                    bool deallocMemref) const;
+  FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
+                               bool deallocMemref) const;
 
   /// Creates a memref deallocation. The given memref buffer must have been
   /// allocated using `createAlloc`.

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 118e25a23148..b2a58069e85a 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -41,9 +41,9 @@ using namespace linalg::comprehensive_bufferize;
 
 /// Default allocation function that is used by the comprehensive bufferization
 /// pass. The default currently creates a ranked memref using `memref.alloc`.
-static Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc,
-                                           MemRefType type,
-                                           ArrayRef<Value> dynShape) {
+static FailureOr<Value> defaultAllocationFn(OpBuilder &b, Location loc,
+                                            MemRefType type,
+                                            ArrayRef<Value> dynShape) {
   Value allocated = b.create<memref::AllocOp>(
       loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
   return allocated;
@@ -391,8 +391,10 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getResultBuffer(
     // allocation should be inserted (in the absence of allocation hoisting).
     setInsertionPointAfter(rewriter, operandBuffer);
     // Allocate the result buffer.
-    Value resultBuffer =
+    FailureOr<Value> resultBuffer =
         createAlloc(rewriter, loc, operandBuffer, options.createDeallocs);
+    if (failed(resultBuffer))
+      return failure();
     bool skipCopy = false;
     // Do not copy if the last preceding write of `operand` is an op that does
     // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
@@ -413,7 +415,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getResultBuffer(
     if (!skipCopy) {
       // The copy happens right before the op that is bufferized.
       rewriter.setInsertionPoint(op);
-      createMemCpy(rewriter, loc, operandBuffer, resultBuffer);
+      createMemCpy(rewriter, loc, operandBuffer, *resultBuffer);
     }
     return resultBuffer;
   }
@@ -537,7 +539,8 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
 /// Create an AllocOp/DeallocOp pair, where the AllocOp is after
 /// `shapedValue.getDefiningOp` (or at the top of the block in case of a
 /// bbArg) and the DeallocOp is at the end of the block.
-Value mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
+FailureOr<Value>
+mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
     OpBuilder &b, Location loc, Value shapedValue, bool deallocMemref) const {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
@@ -549,10 +552,9 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
   // Note: getAllocationTypeAndShape also sets the insertion point.
   MemRefType allocMemRefType =
       getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
-  Optional<Value> allocated = createAlloc(b, loc, allocMemRefType, dynShape);
-  // TODO: For now just assert the value is returned. Eventually need to
-  // error-propagate.
-  assert(allocated && "allocation failed");
+  FailureOr<Value> allocated = createAlloc(b, loc, allocMemRefType, dynShape);
+  if (failed(allocated))
+    return failure();
   Value casted = allocated.getValue();
   if (memRefType && memRefType != allocMemRefType) {
     casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
@@ -568,7 +570,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
 }
 
 /// Create a memref allocation.
-Optional<Value>
+FailureOr<Value>
 mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
     OpBuilder &b, Location loc, MemRefType type,
     ArrayRef<Value> dynShape) const {

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index c4f42afb9828..a3cb3c36065b 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -55,6 +55,8 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
     OpResult opResult = op.getTiedOpResult(opOperand);
     assert(opResult && "could not find correspond OpResult");
     FailureOr<Value> resultBuffer = state.getResultBuffer(rewriter, opResult);
+    if (failed(resultBuffer))
+      return failure();
     newOutputBuffers.push_back(*resultBuffer);
   }
 
@@ -210,10 +212,12 @@ struct InitTensorOpInterface
     if (initTensorOp->getUses().empty())
       return success();
 
-    Value alloc = state.createAlloc(rewriter, initTensorOp->getLoc(),
-                                    initTensorOp.result(),
-                                    state.getOptions().createDeallocs);
-    replaceOpWithBufferizedValues(rewriter, op, alloc);
+    FailureOr<Value> alloc = state.createAlloc(
+        rewriter, initTensorOp->getLoc(), initTensorOp.result(),
+        state.getOptions().createDeallocs);
+    if (failed(alloc))
+      return failure();
+    replaceOpWithBufferizedValues(rewriter, op, *alloc);
     return success();
   }
 };
@@ -287,6 +291,8 @@ struct TiledLoopOpInterface
       if (value.getType().isa<TensorType>()) {
         FailureOr<Value> buffer = state.getResultBuffer(
             rewriter, tiledLoopOp->getResult(nextResultNum++));
+        if (failed(buffer))
+          return failure();
         newOutputs.push_back(*buffer);
         newResults.push_back(*buffer);
       } else {

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index 5983d421aaed..1d62c7880a31 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -295,10 +295,19 @@ struct ForOpInterface
     };
 
     // Construct a new scf.for op with memref instead of tensor values.
+    bool resultBufferFailure = false;
     SmallVector<Value> initArgs =
         convert(forOp.getInitArgs(), [&](Value val, int64_t index) {
-          return *state.getResultBuffer(rewriter, forOp->getOpResult(index));
+          FailureOr<Value> resultBuffer =
+              state.getResultBuffer(rewriter, forOp->getOpResult(index));
+          if (failed(resultBuffer)) {
+            resultBufferFailure = true;
+            return Value();
+          }
+          return *resultBuffer;
         });
+    if (resultBufferFailure)
+      return failure();
     auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
         forOp.getStep(), initArgs);

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index 6b8b8983972a..b6ee0fc63471 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -54,6 +54,8 @@ struct CastOpInterface
     // The result buffer still has the old (pre-cast) type.
     FailureOr<Value> resultBuffer =
         state.getResultBuffer(rewriter, castOp->getResult(0));
+    if (failed(resultBuffer))
+      return failure();
     auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
     Attribute memorySpace = sourceMemRefType.getMemorySpace();
     TensorType resultTensorType =
@@ -149,9 +151,14 @@ struct ExtractSliceOpInterface
     // If not inplaceable, alloc.
     bool inplace = state.isInPlace(extractSliceOp->getResult(0));
     Value alloc;
-    if (!inplace)
-      alloc = state.createAlloc(rewriter, loc, extractSliceOp.result(),
-                                state.getOptions().createDeallocs);
+    if (!inplace) {
+      FailureOr<Value> allocOrFailure =
+          state.createAlloc(rewriter, loc, extractSliceOp.result(),
+                            state.getOptions().createDeallocs);
+      if (failed(allocOrFailure))
+        return failure();
+      alloc = *allocOrFailure;
+    }
 
     // Bufferize to subview.
     auto subviewMemRefType =
@@ -238,6 +245,8 @@ struct InsertOpInterface
     auto insertOp = cast<tensor::InsertOp>(op);
     FailureOr<Value> destMemref =
         state.getResultBuffer(rewriter, insertOp->getOpResult(0));
+    if (failed(destMemref))
+      return failure();
     rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                      *destMemref, insertOp.indices());
     replaceOpWithBufferizedValues(rewriter, op, *destMemref);
@@ -404,6 +413,8 @@ struct InsertSliceOpInterface
     // When bufferizing out-of-place, `getResultBuffer` allocates.
     FailureOr<Value> dstMemref =
         state.getResultBuffer(rewriter, insertSliceOp->getResult(0));
+    if (failed(dstMemref))
+      return failure();
 
     // Take a subview of the dst.
     auto dstMemrefType = dstMemref->getType().cast<MemRefType>();

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
index 3c8d6a9c96e5..58013323cb70 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -100,6 +100,8 @@ struct TransferWriteOpInterface
     // this point.
     FailureOr<Value> resultBuffer =
         state.getResultBuffer(rewriter, op->getResult(0));
+    if (failed(resultBuffer))
+      return failure();
     rewriter.create<vector::TransferWriteOp>(
         writeOp.getLoc(), writeOp.vector(), *resultBuffer, writeOp.indices(),
         writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 13e18001d82e..21d7c4e62a45 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -64,9 +64,9 @@ static void applyEnablingTransformations(ModuleOp moduleOp) {
   (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
 }
 
-static Optional<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
-                                               MemRefType type,
-                                               ArrayRef<Value> dynShape) {
+static FailureOr<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
+                                                MemRefType type,
+                                                ArrayRef<Value> dynShape) {
   Value allocated = b.create<memref::AllocaOp>(
       loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
   return allocated;


        


More information about the Mlir-commits mailing list