[Mlir-commits] [mlir] 65ef43e - [mlir][linalg][bufferize][NFC] Check return value of getResultBuffer

Matthias Springer llvmlistbot at llvm.org
Thu Oct 21 01:27:56 PDT 2021


Author: Matthias Springer
Date: 2021-10-21T17:24:17+09:00
New Revision: 65ef43e288ad1e9fa7a01d2c09a13727f568b870

URL: https://github.com/llvm/llvm-project/commit/65ef43e288ad1e9fa7a01d2c09a13727f568b870
DIFF: https://github.com/llvm/llvm-project/commit/65ef43e288ad1e9fa7a01d2c09a13727f568b870.diff

LOG: [mlir][linalg][bufferize][NFC] Check return value of getResultBuffer

A subsequent commit will allow getResultBuffer to return a "null" Value. This happens when the buffer returned from an scf.if is not unique.

This commit prepares for scf.if support by adding the corresponding null checks at all call sites up front, which keeps the next commit smaller.

Differential Revision: https://reviews.llvm.org/D111927

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index e0e4105870b4..be18cd135a47 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -1422,10 +1422,11 @@ static Value getResultBuffer(OpBuilder &b, OpResult result,
 /// Helper function for LinalgOp bufferization.
 /// When allocating a new buffer, analyze whether `op` wants to read form that
 /// buffer. Only in that case, a copy of the result buffer may be needed.
-static void allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
-                                      SmallVectorImpl<Value> &resultBuffers,
-                                      BlockAndValueMapping &bvm,
-                                      BufferizationAliasInfo &aliasInfo) {
+static LogicalResult
+allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
+                          SmallVectorImpl<Value> &resultBuffers,
+                          BlockAndValueMapping &bvm,
+                          BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);
@@ -1437,11 +1438,15 @@ static void allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
     assert(opResult && "could not find correspond OpResult");
     bool skipCopy = !op.payloadUsesValueFromOperand(opOperand);
     Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo, skipCopy);
+    if (!resultBuffer)
+      return failure();
     resultBuffers.push_back(resultBuffer);
   }
 
   if (op->getNumResults())
     map(bvm, op->getResults(), resultBuffers);
+
+  return success();
 }
 
 /// Generic conversion for any LinalgOp on tensors.
@@ -1469,7 +1474,9 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
   }
   SmallVector<Value> newOutputBuffers;
   // Try to allocate new buffers depending on op's inplace semantics.
-  allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm, aliasInfo);
+  if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm,
+                                       aliasInfo)))
+    return failure();
 
   // Clone the newly bufferized op.
   SmallVector<Value> newOperands = newInputBuffers;
@@ -1638,6 +1645,8 @@ static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,
   b.setInsertionPoint(castOp);
 
   Value resultBuffer = getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo);
+  if (!resultBuffer)
+    return failure();
   Type sourceType = resultBuffer.getType();
   auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
   auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
@@ -1713,6 +1722,8 @@ static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
 
     // TODO: More general: Matching bbArg does not bufferize to a read.
     Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo);
+    if (!resultBuffer)
+      return failure();
 
     OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
     BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
@@ -1834,6 +1845,8 @@ static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,
     const OpResult &opResult = tiledLoopOp->getResult(resultIndex);
     OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex);
     Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo);
+    if (!resultBuffer)
+      return failure();
 
     // Insert mapping and aliasing info.
     aliasInfo.createAliasInfoEntry(resultBuffer);
@@ -1982,6 +1995,9 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
   // buffer.
   Value dstMemref =
       getResultBuffer(b, insertSliceOp->getResult(0), bvm, aliasInfo);
+  if (!dstMemref)
+    return failure();
+
   auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
 
   Value srcMemref = lookup(bvm, insertSliceOp.source());
@@ -2044,6 +2060,8 @@ static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
   // this point.
   auto writeOp = cast<vector::TransferWriteOp>(op.getOperation());
   Value resultBuffer = getResultBuffer(b, op->getResult(0), bvm, aliasInfo);
+  if (!resultBuffer)
+    return failure();
   b.create<vector::TransferWriteOp>(
       op.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
       writeOp.permutation_map(),


        


More information about the Mlir-commits mailing list