[Mlir-commits] [mlir] e287d64 - [mlir] Add translation from tensor.reshape to memref.reshape
Matthias Springer
llvmlistbot at llvm.org
Mon May 9 08:45:13 PDT 2022
Author: Ashay Rane
Date: 2022-05-09T17:45:07+02:00
New Revision: e287d647c61f5fbd054410d6ed9c37d5271f29ef
URL: https://github.com/llvm/llvm-project/commit/e287d647c61f5fbd054410d6ed9c37d5271f29ef
DIFF: https://github.com/llvm/llvm-project/commit/e287d647c61f5fbd054410d6ed9c37d5271f29ef.diff
LOG: [mlir] Add translation from tensor.reshape to memref.reshape
This patch augments the `tensor-bufferize` pass by adding a conversion
rule to translate ReshapeOp from the `tensor` dialect to the `memref`
dialect, in addition to adding a unit test to validate the translation.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D125031
Added:
Modified:
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Tensor/bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 5be820536dab6..b00d87ba54034 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -743,6 +743,54 @@ struct RankOpInterface
}
};
+/// Bufferizes tensor.reshape into memref.reshape on the operand buffers.
+struct ReshapeOpInterface
+ : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
+ tensor::ReshapeOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ if (&opOperand == &op->getOpOperand(1) /* operand #1 is the shape */)
+ return true;
+ return false;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ return false;
+ }
+
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ return {op->getOpResult(0)};
+ }
+
+ BufferRelation bufferRelation(Operation *op, OpResult opResult,
+ const AnalysisState &state) const {
+ return BufferRelation::Equivalent;
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ BufferizationState &state) const {
+ auto reshapeOp = cast<tensor::ReshapeOp>(op);
+ auto &srcOperand = reshapeOp->getOpOperand(0);
+ auto srcBuffer = state.getBuffer(rewriter, srcOperand);
+ if (failed(srcBuffer))
+ return failure();
+
+ auto &shapeOperand = reshapeOp->getOpOperand(1);
+ auto shapeBuffer = state.getBuffer(rewriter, shapeOperand);
+ if (failed(shapeBuffer))
+ return failure();
+
+ auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
+ auto resultMemRefType = getMemRefType(resultTensorType, state.getOptions());
+
+ replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
+ rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
+ return success();
+ }
+};
+
} // namespace
} // namespace tensor
} // namespace mlir
@@ -761,5 +809,6 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
InsertOp::attachInterface<InsertOpInterface>(*ctx);
InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
RankOp::attachInterface<RankOpInterface>(*ctx);
+ ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
});
}
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 587508c698e3f..cd88f2fd4afab 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -408,3 +408,30 @@ func.func @tensor.collapse_shape_of_slice4(%arg0: tensor<?x2x4xf32>, %offset: in
%ret = tensor.collapse_shape %0 [[0, 1, 2]] : tensor<4x2x1xf32> into tensor<8xf32>
return %ret: tensor<8xf32>
}
+
+// CHECK-LABEL: func @tensor.reshape(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x10xf32>
+func.func @tensor.reshape(%t1: tensor<?x10xf32>) -> tensor<2x2x5xf32> {
+ // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10xf32>
+
+ // CHECK: %[[two:.*]] = arith.constant 2 : i64
+ %two = arith.constant 2 : i64
+ // CHECK: %[[five:.*]] = arith.constant 5 : i64
+ %five = arith.constant 5 : i64
+
+ // CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 128 : i64} : memref<3xi64>
+ // CHECK: %[[zero_idx:.*]] = arith.constant 0 : index
+ // CHECK: %[[one_idx:.*]] = arith.constant 1 : index
+ // CHECK: %[[two_idx:.*]] = arith.constant 2 : index
+ // CHECK: memref.store %[[two]], %[[alloc]][%[[zero_idx]]] : memref<3xi64>
+ // CHECK: memref.store %[[two]], %[[alloc]][%[[one_idx]]] : memref<3xi64>
+ // CHECK: memref.store %[[five]], %[[alloc]][%[[two_idx]]] : memref<3xi64>
+ %shape = tensor.from_elements %two, %two, %five : tensor<3xi64>
+
+ // CHECK: %[[reshaped:.*]] = memref.reshape %[[m1]](%[[alloc]]) : (memref<?x10xf32>, memref<3xi64>) -> memref<2x2x5xf32>
+ %reshaped = tensor.reshape %t1(%shape) : (tensor<?x10xf32>, tensor<3xi64>) -> tensor<2x2x5xf32>
+
+ // CHECK: %[[r:.*]] = bufferization.to_tensor %[[reshaped]]
+ // CHECK: return %[[r]]
+ return %reshaped : tensor<2x2x5xf32>
+}
More information about the Mlir-commits
mailing list