[Mlir-commits] [mlir] 53ff0da - [mlir] Fail early if AnalysisState::getBuffer() returns failure
Ashay Rane
llvmlistbot at llvm.org
Tue May 10 08:08:51 PDT 2022
Author: Ashay Rane
Date: 2022-05-10T08:08:38-07:00
New Revision: 53ff0daa7e9d3646ac9de7f0d6ed39359af94738
URL: https://github.com/llvm/llvm-project/commit/53ff0daa7e9d3646ac9de7f0d6ed39359af94738
DIFF: https://github.com/llvm/llvm-project/commit/53ff0daa7e9d3646ac9de7f0d6ed39359af94738.diff
LOG: [mlir] Fail early if AnalysisState::getBuffer() returns failure
This patch updates calls to AnalysisState::getBuffer() so that we return
early with a failure if the call does not succeed.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D125251
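The change applies the usual FailureOr<Value> early-exit idiom at each getBuffer() call site. Below is a minimal, self-contained sketch of that idiom; bufferizeFoo is a hypothetical stand-in for the bufferize() implementations touched in the diff (not part of this patch), and the include paths reflect the in-tree layout around this revision:

  #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
  #include "mlir/IR/PatternMatch.h"
  #include "mlir/Support/LogicalResult.h"

  using namespace mlir;
  using namespace mlir::bufferization;

  // Hypothetical helper, not part of this patch: it mirrors the pattern
  // that each getBuffer() call site in the diff now follows.
  static LogicalResult bufferizeFoo(Operation *op, RewriterBase &rewriter,
                                    BufferizationState &state) {
    // Old form (removed by this patch): dereference unconditionally.
    //   Value buffer = *state.getBuffer(rewriter, op->getOpOperand(0));
    // New form: check for failure before using the buffer.
    FailureOr<Value> buffer = state.getBuffer(rewriter, op->getOpOperand(0));
    if (failed(buffer))
      return failure();
    // *buffer is the memref to use when building the bufferized
    // replacement for `op`.
    (void)*buffer;
    return success();
  }

The old form dereferences an empty FailureOr when getBuffer() fails, which asserts (or misbehaves in builds without assertions); the new form propagates the failure to the caller of bufferize() instead.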
Added:
Modified:
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index b00d87ba54034..efd2de7978e94 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -115,7 +115,9 @@ struct CollapseShapeOpInterface
if (tensorResultType.getRank() == 0) {
// 0-d collapses must go through a different op builder.
- Value buffer = *state.getBuffer(rewriter, srcOperand);
+ auto buffer = state.getBuffer(rewriter, srcOperand);
+ if (failed(buffer))
+ return failure();
MemRefType resultType;
if (bufferType.getLayout().isIdentity()) {
@@ -138,7 +140,7 @@ struct CollapseShapeOpInterface
}
replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
- rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
+ rewriter, op, resultType, *buffer, collapseShapeOp.reassociation());
return success();
}
@@ -152,11 +154,13 @@ struct CollapseShapeOpInterface
? None
: Optional<BufferizationState::ForceInPlacability>(
BufferizationState::ForceInPlacability::FORCE_OUT_OF_PLACE);
- Value buffer = *state.getBuffer(rewriter, srcOperand, overrideInPlace);
+ auto buffer = state.getBuffer(rewriter, srcOperand, overrideInPlace);
+ if (failed(buffer))
+ return failure();
// Result type is inferred by the builder.
replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
- rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
+ rewriter, op, *buffer, collapseShapeOp.getReassociationIndices());
return success();
}
};
@@ -183,8 +187,11 @@ struct DimOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto dimOp = cast<tensor::DimOp>(op);
- Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
- replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
+ auto v = state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
+ if (failed(v))
+ return failure();
+ replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
+ dimOp.index());
return success();
}
};
@@ -219,13 +226,15 @@ struct ExpandShapeOpInterface
BufferizationState &state) const {
auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
auto tensorResultType = expandShapeOp.getResultType();
- Value buffer =
- *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
+ auto buffer =
+ state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
+ if (failed(buffer))
+ return failure();
// Memref result type is inferred by the builder based on reassociation
// indices and result shape.
replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
- rewriter, op, tensorResultType.getShape(), buffer,
+ rewriter, op, tensorResultType.getShape(), *buffer,
expandShapeOp.getReassociationIndices());
return success();
}
@@ -264,10 +273,12 @@ struct ExtractSliceOpInterface
// Even if this op was decided to bufferize out-of-place, do not insert the
// buffer copy yet. This is done later in this function.
- Value srcMemref =
- *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
- BufferizationState::ForceInPlacability::FORCE_INPLACE);
- auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
+ auto srcMemref =
+ state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
+ BufferizationState::ForceInPlacability::FORCE_INPLACE);
+ if (failed(srcMemref))
+ return failure();
+ auto srcMemrefType = srcMemref->getType().cast<MemRefType>();
auto dstTensorType =
extractSliceOp.result().getType().cast<RankedTensorType>();
@@ -289,7 +300,7 @@ struct ExtractSliceOpInterface
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
OffsetSizeAndStrideOpInterface::expandToRank(
- srcMemref, mixedOffsets, mixedSizes, mixedStrides,
+ *srcMemref, mixedOffsets, mixedSizes, mixedStrides,
[&](Value target, int64_t dim) -> OpFoldResult {
auto shapedType = target.getType().cast<ShapedType>();
if (shapedType.isDynamicDim(dim))
@@ -302,7 +313,7 @@ struct ExtractSliceOpInterface
mixedOffsets, mixedSizes, mixedStrides)
.cast<MemRefType>();
Value subView = rewriter.create<memref::SubViewOp>(
- loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
+ loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
mixedStrides);
// If not inplaceable, copy.
@@ -342,9 +353,11 @@ struct ExtractOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto extractOp = cast<tensor::ExtractOp>(op);
- Value srcMemref =
- *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
- replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
+ auto srcMemref =
+ state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
+ if (failed(srcMemref))
+ return failure();
+ replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
extractOp.indices());
return success();
}
@@ -703,10 +716,10 @@ struct InsertSliceOpInterface
// Copy tensor. If this tensor.insert_slice has a matching
// tensor.extract_slice, the copy operation will eventually fold away.
- Value srcMemref =
- *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
- if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
- state.getOptions())))
+ auto srcMemref =
+ state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
+ if (failed(srcMemref) || failed(createMemCpy(rewriter, loc, *srcMemref,
+ subView, state.getOptions())))
return failure();
replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
@@ -736,9 +749,11 @@ struct RankOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto rankOp = cast<tensor::RankOp>(op);
- Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
+ auto v = state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
+ if (failed(v))
+ return failure();
replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
- v);
+ *v);
return success();
}
};