[Mlir-commits] [mlir] e6f6916 - [mlir][bufferize] Support tensor.expand_shape and tensor.collapse_shape
Matthias Springer
llvmlistbot at llvm.org
Tue Feb 15 02:57:53 PST 2022
Author: Matthias Springer
Date: 2022-02-15T19:53:49+09:00
New Revision: e6f691615e481d74cad8c6369bc0116192630ea1
URL: https://github.com/llvm/llvm-project/commit/e6f691615e481d74cad8c6369bc0116192630ea1
DIFF: https://github.com/llvm/llvm-project/commit/e6f691615e481d74cad8c6369bc0116192630ea1.diff
LOG: [mlir][bufferize] Support tensor.expand_shape and tensor.collapse_shape
Differential Revision: https://reviews.llvm.org/D112512
Added:
Modified:
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Tensor/bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 91e10916124ab..ad5485de63317 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -80,6 +80,46 @@ struct CastOpInterface
}
};
+/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
+///
+/// The op only reinterprets the shape of its source operand: the single result
+/// aliases the source and is reported as equivalent to it, and the op itself
+/// neither reads nor writes any buffer element.
+struct CollapseShapeOpInterface
+    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
+                                                    tensor::CollapseShapeOp> {
+  /// Collapsing only changes how the buffer is viewed; nothing is read.
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const BufferizationState &state) const {
+    return false;
+  }
+
+  /// Collapsing only changes how the buffer is viewed; nothing is written.
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const BufferizationState &state) const {
+    return false;
+  }
+
+  /// The source operand (operand #0) aliases the op's single result; any other
+  /// operand has no aliasing result.
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
+    if (&opOperand == &op->getOpOperand(0) /*src*/)
+      return {op->getOpResult(0)};
+    return {};
+  }
+
+  /// The result buffer is the source buffer viewed with a collapsed shape.
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationState &state) const {
+    return BufferRelation::Equivalent;
+  }
+
+  /// Replace the op with a memref.collapse_shape of the source buffer, using
+  /// the op's reassociation. Returns failure if the source buffer cannot be
+  /// obtained.
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationState &state) const {
+    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
+    // getBuffer reports errors through its FailureOr result. Dereferencing a
+    // failed FailureOr asserts, so check it and propagate the failure instead.
+    FailureOr<Value> buffer =
+        state.getBuffer(rewriter, collapseShapeOp->getOpOperand(0) /*src*/);
+    if (failed(buffer))
+      return failure();
+    // NOTE(review): the memref result type is derived from the tensor result
+    // type only, not from the layout of the source buffer; confirm this still
+    // verifies when the source buffer has a non-identity layout.
+    Type resultType =
+        getMemRefType(collapseShapeOp.getResultType(), state.getOptions());
+    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
+        rewriter, op, resultType, *buffer, collapseShapeOp.reassociation());
+    return success();
+  }
+};
+
/// Bufferization of tensor.dim. Replace with memref.dim.
struct DimOpInterface
: public BufferizableOpInterface::ExternalModel<DimOpInterface,
@@ -109,6 +149,46 @@ struct DimOpInterface
}
};
+/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
+///
+/// The op only reinterprets the shape of its source operand: the single result
+/// aliases the source and is reported as equivalent to it, and the op itself
+/// neither reads nor writes any buffer element.
+struct ExpandShapeOpInterface
+    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
+                                                    tensor::ExpandShapeOp> {
+  /// Expanding only changes how the buffer is viewed; nothing is read.
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const BufferizationState &state) const {
+    return false;
+  }
+
+  /// Expanding only changes how the buffer is viewed; nothing is written.
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const BufferizationState &state) const {
+    return false;
+  }
+
+  /// The source operand (operand #0) aliases the op's single result; any other
+  /// operand has no aliasing result.
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
+    if (&opOperand == &op->getOpOperand(0) /*src*/)
+      return {op->getOpResult(0)};
+    return {};
+  }
+
+  /// The result buffer is the source buffer viewed with an expanded shape.
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationState &state) const {
+    return BufferRelation::Equivalent;
+  }
+
+  /// Replace the op with a memref.expand_shape of the source buffer, using the
+  /// op's reassociation. Returns failure if the source buffer cannot be
+  /// obtained.
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationState &state) const {
+    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
+    // getBuffer reports errors through its FailureOr result. Dereferencing a
+    // failed FailureOr asserts, so check it and propagate the failure instead.
+    FailureOr<Value> buffer =
+        state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
+    if (failed(buffer))
+      return failure();
+    // NOTE(review): the memref result type is derived from the tensor result
+    // type only, not from the layout of the source buffer; confirm this still
+    // verifies when the source buffer has a non-identity layout.
+    Type resultType =
+        getMemRefType(expandShapeOp.getResultType(), state.getOptions());
+    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
+        rewriter, op, resultType, *buffer, expandShapeOp.reassociation());
+    return success();
+  }
+};
+
/// Bufferization of tensor.extract_slice. Replace with memref.subview.
struct ExtractSliceOpInterface
: public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
@@ -635,7 +715,9 @@ struct RankOpInterface
void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
DialectRegistry ®istry) {
registry.addOpInterface<CastOp, CastOpInterface>();
+ registry.addOpInterface<CollapseShapeOp, CollapseShapeOpInterface>();
registry.addOpInterface<DimOp, DimOpInterface>();
+ registry.addOpInterface<ExpandShapeOp, ExpandShapeOpInterface>();
registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
registry.addOpInterface<ExtractOp, ExtractOpInterface>();
registry.addOpInterface<FromElementsOp, FromElementsOpInterface>();
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index b0415ce1464ce..7d3084d9d024c 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -301,3 +301,31 @@ func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5xf32>
// CHECK: return %[[r]]
return %0 : tensor<5xf32>
}
+
+// Bufferizing tensor.expand_shape is expected to produce a memref.expand_shape
+// of the function argument's buffer (obtained via bufferization.to_memref) with
+// the same reassociation groups {0, 1} and {2}. The expanded memref is then
+// converted back with bufferization.to_tensor and returned.
+// The reassociation pattern below is split across two directives, presumably
+// because a literal double open bracket would be parsed as FileCheck
+// variable syntax.
+// CHECK-LABEL: func @tensor.expand_shape(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x10xf32>
+func @tensor.expand_shape(%t1: tensor<?x10xf32>) -> tensor<2x?x10xf32> {
+ // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10xf32>
+ // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] [
+ // CHECK-SAME: [0, 1], [2]] : memref<?x10xf32> into memref<2x?x10xf32>
+ %0 = tensor.expand_shape %t1 [[0, 1], [2]]
+ : tensor<?x10xf32> into tensor<2x?x10xf32>
+
+ // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
+ // CHECK: return %[[r]]
+ return %0 : tensor<2x?x10xf32>
+}
+
+// Bufferizing tensor.collapse_shape is expected to produce a
+// memref.collapse_shape of the function argument's buffer (obtained via
+// bufferization.to_memref) with the same reassociation groups {0, 1} and {2}:
+// the static leading dim and the first dynamic dim fold into one dynamic dim.
+// The collapsed memref is then converted back with bufferization.to_tensor and
+// returned.
+// CHECK-LABEL: func @tensor.collapse_shape(
+// CHECK-SAME: %[[t1:.*]]: tensor<2x?x?xf32>
+func @tensor.collapse_shape(%t1: tensor<2x?x?xf32>) -> tensor<?x?xf32> {
+ // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<2x?x?xf32>
+ // CHECK: %[[collapsed:.*]] = memref.collapse_shape %[[m1]] [
+ // CHECK-SAME: [0, 1], [2]] : memref<2x?x?xf32> into memref<?x?xf32>
+ %0 = tensor.collapse_shape %t1 [[0, 1], [2]]
+ : tensor<2x?x?xf32> into tensor<?x?xf32>
+
+ // CHECK: %[[r:.*]] = bufferization.to_tensor %[[collapsed]]
+ // CHECK: return %[[r]]
+ return %0 : tensor<?x?xf32>
+}
More information about the Mlir-commits
mailing list