[Mlir-commits] [mlir] 9dbb8ee - [mlir][tensor] Implement getBufferType for ReshapeOp.
Ingo Müller
llvmlistbot at llvm.org
Fri Jun 2 06:40:53 PDT 2023
Author: Ingo Müller
Date: 2023-06-02T13:40:48Z
New Revision: 9dbb8eefd43b7acdd5d6d030deed839eb11bbd9c
URL: https://github.com/llvm/llvm-project/commit/9dbb8eefd43b7acdd5d6d030deed839eb11bbd9c
DIFF: https://github.com/llvm/llvm-project/commit/9dbb8eefd43b7acdd5d6d030deed839eb11bbd9c.diff
LOG: [mlir][tensor] Implement getBufferType for ReshapeOp.
This function should be implemented for ops that work in one-shot
bufferization.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D151548
Added:
Modified:
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 1a4fc3bb5bbad..e824c73a1f079 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -992,13 +992,28 @@ struct ReshapeOpInterface
getBuffer(rewriter, reshapeOp.getShape(), options);
if (failed(srcBuffer) || failed(shapeBuffer))
return failure();
- auto resultMemRefType = getMemRefTypeWithStaticIdentityLayout(
- reshapeOp.getResult().getType(),
- cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
+ auto maybeResultMemRefType =
+ bufferization::getBufferType(reshapeOp.getResult(), options);
+ if (failed(maybeResultMemRefType))
+ return failure();
replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
- rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
+ rewriter, op, maybeResultMemRefType.value(), *srcBuffer, *shapeBuffer);
return success();
}
+
// NOTE(review): annotation added during review; not part of the original commit.
// Computes the memref type for the result of a tensor.reshape op without
// rewriting any IR. `value` must be the op's result. Returns failure() when the
// buffer type of the reshape source cannot be determined.
+ FailureOr<BaseMemRefType>
+ getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+ auto reshapeOp = cast<tensor::ReshapeOp>(op);
// NOTE(review): only the reshape result may be queried through this hook.
+ assert(value == reshapeOp.getResult() && "unexpected value provided");
// NOTE(review): the source's buffer type is looked up with the same `options`
// and `fixedTypes` that were passed in; its memory space is reused below.
+ auto maybeSourceBufferType = bufferization::getBufferType(
+ reshapeOp.getSource(), options, fixedTypes);
+ if (failed(maybeSourceBufferType))
+ return failure();
// NOTE(review): the result type is built from the result tensor type plus the
// source buffer's memory space. The cast<BaseMemRefType> looks redundant if
// bufferization::getBufferType already yields a BaseMemRefType — confirm.
+ return getMemRefTypeWithStaticIdentityLayout(
+ reshapeOp.getResult().getType(),
+ cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace());
+ }
};
/// Analysis of ParallelInsertSliceOp.
More information about the Mlir-commits mailing list