[Mlir-commits] [mlir] 53ff0da - [mlir] Fail early if AnalysisState::getBuffer() returns failure

Ashay Rane llvmlistbot at llvm.org
Tue May 10 08:08:51 PDT 2022


Author: Ashay Rane
Date: 2022-05-10T08:08:38-07:00
New Revision: 53ff0daa7e9d3646ac9de7f0d6ed39359af94738

URL: https://github.com/llvm/llvm-project/commit/53ff0daa7e9d3646ac9de7f0d6ed39359af94738
DIFF: https://github.com/llvm/llvm-project/commit/53ff0daa7e9d3646ac9de7f0d6ed39359af94738.diff

LOG: [mlir] Fail early if AnalysisState::getBuffer() returns failure

This patch updates calls to AnalysisState::getBuffer() so that we return
early with a failure if the call does not succeed.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D125251

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index b00d87ba54034..efd2de7978e94 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -115,7 +115,9 @@ struct CollapseShapeOpInterface
 
     if (tensorResultType.getRank() == 0) {
       // 0-d collapses must go through a different op builder.
-      Value buffer = *state.getBuffer(rewriter, srcOperand);
+      auto buffer = state.getBuffer(rewriter, srcOperand);
+      if (failed(buffer))
+        return failure();
       MemRefType resultType;
 
       if (bufferType.getLayout().isIdentity()) {
@@ -138,7 +140,7 @@ struct CollapseShapeOpInterface
       }
 
       replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
-          rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
+          rewriter, op, resultType, *buffer, collapseShapeOp.reassociation());
       return success();
     }
 
@@ -152,11 +154,13 @@ struct CollapseShapeOpInterface
             ? None
             : Optional<BufferizationState::ForceInPlacability>(
                   BufferizationState::ForceInPlacability::FORCE_OUT_OF_PLACE);
-    Value buffer = *state.getBuffer(rewriter, srcOperand, overrideInPlace);
+    auto buffer = state.getBuffer(rewriter, srcOperand, overrideInPlace);
+    if (failed(buffer))
+      return failure();
 
     // Result type is inferred by the builder.
     replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
-        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
+        rewriter, op, *buffer, collapseShapeOp.getReassociationIndices());
     return success();
   }
 };
@@ -183,8 +187,11 @@ struct DimOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto dimOp = cast<tensor::DimOp>(op);
-    Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
-    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
+    auto v = state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
+    if (failed(v))
+      return failure();
+    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
+                                                dimOp.index());
     return success();
   }
 };
@@ -219,13 +226,15 @@ struct ExpandShapeOpInterface
                           BufferizationState &state) const {
     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
     auto tensorResultType = expandShapeOp.getResultType();
-    Value buffer =
-        *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
+    auto buffer =
+        state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
+    if (failed(buffer))
+      return failure();
 
     // Memref result type is inferred by the builder based on reassociation
     // indices and result shape.
     replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
-        rewriter, op, tensorResultType.getShape(), buffer,
+        rewriter, op, tensorResultType.getShape(), *buffer,
         expandShapeOp.getReassociationIndices());
     return success();
   }
@@ -264,10 +273,12 @@ struct ExtractSliceOpInterface
 
     // Even if this op was decided to bufferize out-of-place, do not insert the
     // buffer copy yet. This is done later in this function.
-    Value srcMemref =
-        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
-                         BufferizationState::ForceInPlacability::FORCE_INPLACE);
-    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
+    auto srcMemref =
+        state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
+                        BufferizationState::ForceInPlacability::FORCE_INPLACE);
+    if (failed(srcMemref))
+      return failure();
+    auto srcMemrefType = srcMemref->getType().cast<MemRefType>();
     auto dstTensorType =
         extractSliceOp.result().getType().cast<RankedTensorType>();
 
@@ -289,7 +300,7 @@ struct ExtractSliceOpInterface
     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
     OffsetSizeAndStrideOpInterface::expandToRank(
-        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
+        *srcMemref, mixedOffsets, mixedSizes, mixedStrides,
         [&](Value target, int64_t dim) -> OpFoldResult {
           auto shapedType = target.getType().cast<ShapedType>();
           if (shapedType.isDynamicDim(dim))
@@ -302,7 +313,7 @@ struct ExtractSliceOpInterface
                                  mixedOffsets, mixedSizes, mixedStrides)
                                  .cast<MemRefType>();
     Value subView = rewriter.create<memref::SubViewOp>(
-        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
+        loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
         mixedStrides);
 
     // If not inplaceable, copy.
@@ -342,9 +353,11 @@ struct ExtractOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto extractOp = cast<tensor::ExtractOp>(op);
-    Value srcMemref =
-        *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
-    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
+    auto srcMemref =
+        state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
+    if (failed(srcMemref))
+      return failure();
+    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
                                                  extractOp.indices());
     return success();
   }
@@ -703,10 +716,10 @@ struct InsertSliceOpInterface
 
     // Copy tensor. If this tensor.insert_slice has a matching
     // tensor.extract_slice, the copy operation will eventually fold away.
-    Value srcMemref =
-        *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
-    if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
-                            state.getOptions())))
+    auto srcMemref =
+        state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
+    if (failed(srcMemref) || failed(createMemCpy(rewriter, loc, *srcMemref,
+                                                 subView, state.getOptions())))
       return failure();
 
     replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
@@ -736,9 +749,11 @@ struct RankOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto rankOp = cast<tensor::RankOp>(op);
-    Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
+    auto v = state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
+    if (failed(v))
+      return failure();
     replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
-                                                 v);
+                                                 *v);
     return success();
   }
 };


        


More information about the Mlir-commits mailing list