[Mlir-commits] [mlir] 45b995c - [mlir][bufferize][NFC] Change signature of allocateTensorForShapedValue

Matthias Springer llvmlistbot at llvm.org
Mon Jun 27 07:03:37 PDT 2022


Author: Matthias Springer
Date: 2022-06-27T16:00:06+02:00
New Revision: 45b995cda4611d6c0f1b4c2e8b7903303ce5d49c

URL: https://github.com/llvm/llvm-project/commit/45b995cda4611d6c0f1b4c2e8b7903303ce5d49c
DIFF: https://github.com/llvm/llvm-project/commit/45b995cda4611d6c0f1b4c2e8b7903303ce5d49c.diff

LOG: [mlir][bufferize][NFC] Change signature of allocateTensorForShapedValue

Add a failure return value and a bufferization options argument. This keeps a subsequent change smaller.
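For context, the caller-side pattern after this change looks roughly like the following. This is a minimal sketch, not code from the commit: the helper name `insertTensorCopy` and its surrounding function are made up for illustration, while the call itself mirrors the updated call sites in the diff below (failure propagation plus dereferencing the `FailureOr<Value>`).

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;

// Hypothetical helper, for illustration only.
static LogicalResult insertTensorCopy(RewriterBase &rewriter, Operation *op,
                                      OpOperand &operand,
                                      const AnalysisState &state) {
  // New signature: returns FailureOr<Value> and takes BufferizationOptions.
  FailureOr<Value> alloc = allocateTensorForShapedValue(
      rewriter, op->getLoc(), operand.get(), /*escape=*/true,
      state.getOptions(), /*copy=*/true);
  if (failed(alloc))
    return failure();
  // On success, dereference to get the newly created alloc_tensor value.
  rewriter.updateRootInPlace(op, [&]() { operand.set(*alloc); });
  return success();
}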

Differential Revision: https://reviews.llvm.org/D128278

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index b609a7fd78fb..ff8db00f7644 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -472,9 +472,10 @@ class AnalysisState {
 /// Create an AllocTensorOp for the given shaped value (memref or tensor).
 /// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
 /// undefined contents is allocated.
-Value allocateTensorForShapedValue(OpBuilder &b, Location loc,
-                                   Value shapedValue, bool escape,
-                                   bool copy = true);
+FailureOr<Value>
+allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue,
+                             bool escape, const BufferizationOptions &options,
+                             bool copy = true);
 
 /// Lookup the buffer for the given value. If the value was not bufferized
 /// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp,

diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 7e5ccd031abe..6073c931e53a 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -46,9 +46,9 @@ constexpr const ::llvm::StringLiteral
 /// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
 /// shaped value is copied. Otherwise, a tensor with undefined contents is
 /// allocated.
-Value bufferization::allocateTensorForShapedValue(OpBuilder &b, Location loc,
-                                                  Value shapedValue,
-                                                  bool escape, bool copy) {
+FailureOr<Value> bufferization::allocateTensorForShapedValue(
+    OpBuilder &b, Location loc, Value shapedValue, bool escape,
+    const BufferizationOptions &options, bool copy) {
   Value tensor;
   if (shapedValue.getType().isa<RankedTensorType>()) {
     tensor = shapedValue;
@@ -88,7 +88,7 @@ Value bufferization::allocateTensorForShapedValue(OpBuilder &b, Location loc,
                                                copy ? tensor : Value());
   allocTensorOp->setAttr(BufferizationDialect::kEscapeAttrName,
                          b.getBoolArrayAttr({escape}));
-  return allocTensorOp;
+  return allocTensorOp.getResult();
 }
 
 LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
@@ -147,26 +147,30 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
   // Insert copies of OpOperands.
   rewriter.setInsertionPoint(op);
   for (OpOperand *opOperand : outOfPlaceOpOperands) {
-    Value copy = allocateTensorForShapedValue(
+    FailureOr<Value> copy = allocateTensorForShapedValue(
         rewriter, op->getLoc(), opOperand->get(),
-        escapingOpOperandCopies.contains(opOperand),
+        escapingOpOperandCopies.contains(opOperand), state.getOptions(),
         copiedOpOperands.contains(opOperand));
-    rewriter.updateRootInPlace(op, [&]() { opOperand->set(copy); });
+    if (failed(copy))
+      return failure();
+    rewriter.updateRootInPlace(op, [&]() { opOperand->set(*copy); });
   }
 
   // Insert copies of OpResults.
   rewriter.setInsertionPointAfter(op);
   for (OpResult opResult : outOfPlaceOpResults) {
-    Value copy =
-        allocateTensorForShapedValue(rewriter, op->getLoc(), opResult,
-                                     escapingOpResultCopies.contains(opResult),
-                                     copiedOpResults.count(opResult));
+    FailureOr<Value> copy = allocateTensorForShapedValue(
+        rewriter, op->getLoc(), opResult,
+        escapingOpResultCopies.contains(opResult), state.getOptions(),
+        copiedOpResults.count(opResult));
+    if (failed(copy))
+      return failure();
     SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
         opResult.getUses(), [](OpOperand &use) { return &use; }));
     for (OpOperand *use : uses) {
       // Do not update the alloc_tensor op that we just created.
-      if (use->getOwner() != copy.getDefiningOp())
-        rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(copy); });
+      if (use->getOwner() != copy->getDefiningOp())
+        rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(*copy); });
     }
   }
 

diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 36d0f0cefbae..0bf4fd3405dd 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -458,9 +458,12 @@ struct ForOpInterface
         yieldValues.push_back(value);
         continue;
       }
-      Value alloc = allocateTensorForShapedValue(rewriter, yieldOp.getLoc(),
-                                                 value, /*escape=*/true);
-      yieldValues.push_back(alloc);
+      FailureOr<Value> alloc =
+          allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
+                                       /*escape=*/true, state.getOptions());
+      if (failed(alloc))
+        return failure();
+      yieldValues.push_back(*alloc);
     }
 
     rewriter.updateRootInPlace(
@@ -669,9 +672,12 @@ struct WhileOpInterface
         beforeYieldValues.push_back(value);
         continue;
       }
-      Value alloc = allocateTensorForShapedValue(rewriter, conditionOp.getLoc(),
-                                                 value, /*escape=*/true);
-      beforeYieldValues.push_back(alloc);
+      FailureOr<Value> alloc =
+          allocateTensorForShapedValue(rewriter, conditionOp.getLoc(), value,
+                                       /*escape=*/true, state.getOptions());
+      if (failed(alloc))
+        return failure();
+      beforeYieldValues.push_back(*alloc);
     }
     rewriter.updateRootInPlace(conditionOp, [&]() {
       conditionOp.getArgsMutable().assign(beforeYieldValues);
@@ -687,9 +693,12 @@ struct WhileOpInterface
         afterYieldValues.push_back(value);
         continue;
       }
-      Value alloc = allocateTensorForShapedValue(rewriter, yieldOp.getLoc(),
-                                                 value, /*escape=*/true);
-      afterYieldValues.push_back(alloc);
+      FailureOr<Value> alloc =
+          allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
+                                       /*escape=*/true, state.getOptions());
+      if (failed(alloc))
+        return failure();
+      afterYieldValues.push_back(*alloc);
     }
     rewriter.updateRootInPlace(yieldOp, [&]() {
       yieldOp.getResultsMutable().assign(afterYieldValues);
@@ -972,13 +981,15 @@ struct ForeachThreadOpInterface
 
       // Insert tensor allocation.
       bool isYielded = state.isTensorYielded(opResult);
-      Value alloc = allocateTensorForShapedValue(rewriter, op->getLoc(),
-                                                 destOperands.front()->get(),
-                                                 /*escape=*/isYielded);
+      FailureOr<Value> alloc = allocateTensorForShapedValue(
+          rewriter, op->getLoc(), destOperands.front()->get(),
+          /*escape=*/isYielded, state.getOptions());
+      if (failed(alloc))
+        return failure();
 
       // Update terminator operand.
       rewriter.updateRootInPlace(destOperands.front()->getOwner(),
-                                 [&]() { destOperands.front()->set(alloc); });
+                                 [&]() { destOperands.front()->set(*alloc); });
     }
 
     return success();

diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index e7e31dcd42f5..6f2484352210 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -154,15 +154,17 @@ struct CollapseShapeOpInterface
     if (!canBeCollapsed) {
       // TODO: Create alloc_tensor ops during TensorCopyInsertion.
       AnalysisState analysisState(options);
-      Value tensorAlloc = allocateTensorForShapedValue(
+      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
           rewriter, op->getLoc(), collapseShapeOp.getSrc(),
-          analysisState.isTensorYielded(collapseShapeOp.getResult()));
+          analysisState.isTensorYielded(collapseShapeOp.getResult()), options);
+      if (failed(tensorAlloc))
+        return failure();
       auto memrefType =
           MemRefType::get(collapseShapeOp.getSrcType().getShape(),
                           collapseShapeOp.getSrcType().getElementType(),
                           AffineMap(), bufferType.getMemorySpaceAsInt());
       buffer = rewriter.create<bufferization::ToMemrefOp>(
-          op->getLoc(), memrefType, tensorAlloc);
+          op->getLoc(), memrefType, *tensorAlloc);
     }
 
     // Result type is inferred by the builder.
@@ -383,14 +385,16 @@ struct FromElementsOpInterface
     auto shape = tensorType.getShape();
     // TODO: Create alloc_tensor ops during TensorCopyInsertion.
     AnalysisState analysisState(options);
-    Value tensorAlloc = allocateTensorForShapedValue(
+    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
         rewriter, loc, fromElementsOp.getResult(),
-        analysisState.isTensorYielded(fromElementsOp.getResult()),
+        analysisState.isTensorYielded(fromElementsOp.getResult()), options,
         /*copy=*/false);
+    if (failed(tensorAlloc))
+      return failure();
     auto memrefType =
         MemRefType::get(tensorType.getShape(), tensorType.getElementType());
     Value buffer = rewriter.create<bufferization::ToMemrefOp>(
-        op->getLoc(), memrefType, tensorAlloc);
+        op->getLoc(), memrefType, *tensorAlloc);
 
     // Case: tensor<0xelem_type>.
     if (fromElementsOp.getElements().empty()) {
@@ -436,14 +440,16 @@ struct GenerateOpInterface
     Location loc = op->getLoc();
     // TODO: Create alloc_tensor ops during TensorCopyInsertion.
     AnalysisState analysisState(options);
-    Value tensorAlloc = allocateTensorForShapedValue(
+    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
         rewriter, loc, generateOp.getResult(),
-        analysisState.isTensorYielded(generateOp.getResult()),
+        analysisState.isTensorYielded(generateOp.getResult()), options,
         /*copy=*/false);
+    if (failed(tensorAlloc))
+      return failure();
     auto memrefType =
         MemRefType::get(tensorType.getShape(), tensorType.getElementType());
     Value buffer = rewriter.create<bufferization::ToMemrefOp>(
-        op->getLoc(), memrefType, tensorAlloc);
+        op->getLoc(), memrefType, *tensorAlloc);
 
     // Collect loop bounds.
     int64_t rank = memrefType.getRank();


        

