[Mlir-commits] [mlir] fc08d1c - [mlir][tensor][bufferize] Support tensor.rank in BufferizableOpInterfaceImpl
Matthias Springer
llvmlistbot at llvm.org
Mon Jan 24 07:31:54 PST 2022
Author: Matthias Springer
Date: 2022-01-25T00:31:20+09:00
New Revision: fc08d1c2940609d26a534d7a12e6c6a528891830
URL: https://github.com/llvm/llvm-project/commit/fc08d1c2940609d26a534d7a12e6c6a528891830
DIFF: https://github.com/llvm/llvm-project/commit/fc08d1c2940609d26a534d7a12e6c6a528891830.diff
LOG: [mlir][tensor][bufferize] Support tensor.rank in BufferizableOpInterfaceImpl
tensor.rank is the only remaining op that is not yet supported via BufferizableOpInterfaceImpl bufferization. Once this op is supported, we can switch `tensor-bufferize` over to the new unified bufferization.
Differential Revision: https://reviews.llvm.org/D117985
Added:
Modified:
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index f4a2a5d692152..0fe79862a69d0 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -457,16 +457,22 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
Value memref = frontBlock.addArgument(memrefType, bbArg.getLoc());
OpBuilder b(funcOp->getContext());
b.setInsertionPointToStart(&frontBlock);
- // Replace all uses of bbArg through a ToMemRefOp by a memref::CastOp.
+ // Replace all uses of bbArg through a ToMemRefOp.
for (auto &use : llvm::make_early_inc_range(bbArg.getUses())) {
if (auto toMemrefOp =
dyn_cast<bufferization::ToMemrefOp>(use.getOwner())) {
- assert(memref::CastOp::areCastCompatible(
- memref.getType(), toMemrefOp.memref().getType()) &&
- "bufferizeFuncOpBoundary: cast incompatible");
- auto castOp = b.create<memref::CastOp>(
- funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
- toMemrefOp.memref().replaceAllUsesWith(castOp);
+ if (memref.getType() != toMemrefOp.memref().getType()) {
+ // Type has changed, insert a cast.
+ assert(memref::CastOp::areCastCompatible(
+ memref.getType(), toMemrefOp.memref().getType()) &&
+ "bufferizeFuncOpBoundary: cast incompatible");
+ auto castOp = b.create<memref::CastOp>(
+ funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
+ toMemrefOp.memref().replaceAllUsesWith(castOp);
+ } else {
+ // Type did not change, replace directly.
+ toMemrefOp.memref().replaceAllUsesWith(memref);
+ }
}
}
// Replace all remaining uses by a to_tensor.
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 03fa45f04c68f..ea9d885736f90 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -463,6 +463,35 @@ struct InsertSliceOpInterface
}
};
+/// Bufferization of tensor.rank. Replace with memref.rank.
+struct RankOpInterface
+ : public BufferizableOpInterface::ExternalModel<RankOpInterface,
+ tensor::RankOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const BufferizationState &state) const {
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const BufferizationState &state) const {
+ return false;
+ }
+
+ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const BufferizationState &state) const {
+ return OpResult();
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const BufferizationState &state) const {
+ auto rankOp = cast<tensor::RankOp>(op);
+ Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
+ replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
+ v);
+ return success();
+ }
+};
+
} // namespace
} // namespace tensor
} // namespace mlir
@@ -475,4 +504,5 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
registry.addOpInterface<ExtractOp, ExtractOpInterface>();
registry.addOpInterface<InsertOp, InsertOpInterface>();
registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
+ registry.addOpInterface<RankOp, RankOpInterface>();
}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index a739fc4645ed0..1f301a14c11e1 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1348,3 +1348,14 @@ func @write_after_select_read_one(
// CHECK: return %[[f]], %[[select]]
return %f, %w : f32, tensor<?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @tensor_rank(
+// CHECK-SAME: %[[arg0:.*]]: memref<*xf32>
+func @tensor_rank(%arg0: tensor<*xf32>) -> index {
+ // CHECK: %[[r:.*]] = memref.rank %[[arg0]]
+ %0 = tensor.rank %arg0 : tensor<*xf32>
+ // CHECK: return %[[r]] : index
+ return %0 : index
+}
More information about the Mlir-commits mailing list