[Mlir-commits] [mlir] 481b254 - [mlir][tensor][bufferize] Bufferize tensor.splat op
Matthias Springer
llvmlistbot at llvm.org
Mon May 22 05:31:52 PDT 2023
Author: Matthias Springer
Date: 2023-05-22T14:31:39+02:00
New Revision: 481b254e458bc195af16fef9625cf856ef87fced
URL: https://github.com/llvm/llvm-project/commit/481b254e458bc195af16fef9625cf856ef87fced
DIFF: https://github.com/llvm/llvm-project/commit/481b254e458bc195af16fef9625cf856ef87fced.diff
LOG: [mlir][tensor][bufferize] Bufferize tensor.splat op
The op bufferizes similarly to tensor.generate: it is lowered to a linalg.map, which may then lower to a loop nest that fills the buffer.
Differential Revision: https://reviews.llvm.org/D150952
Added:
Modified:
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Tensor/bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 9253bc2ffeb7e..935a1b9fededf 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1087,6 +1087,54 @@ struct ParallelInsertSliceOpInterface
}
};
+/// Bufferization of tensor.splat. Bufferizes to a new allocation that is filled
+/// with a linalg.map that yields the splat value. Similar to tensor.generate.
+struct SplatOpInterface
+    : public BufferizableOpInterface::ExternalModel<SplatOpInterface,
+                                                    tensor::SplatOp> {
+  /// tensor.splat always materializes its result in a newly allocated buffer.
+  bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
+    return true;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto splatOp = cast<tensor::SplatOp>(op);
+
+    // Should the buffer be deallocated?
+    bool dealloc =
+        shouldDeallocateOpResult(cast<OpResult>(splatOp.getResult()), options);
+
+    // TODO: Implement memory space for this op.
+    if (options.defaultMemorySpace != Attribute())
+      return op->emitError("memory space not implemented yet");
+
+    // Allocate memory. The contents are written by the map below: no copy.
+    Location loc = op->getLoc();
+    FailureOr<Value> tensorAlloc =
+        allocateTensorForShapedValue(rewriter, loc, splatOp.getResult(),
+                                     /*escape=*/!dealloc, options,
+                                     /*copy=*/false);
+    if (failed(tensorAlloc))
+      return failure();
+
+    // Create a linalg::MapOp without inputs that writes into the allocation.
+    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
+    auto linalgOp =
+        rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
+                                       /*init=*/*tensorAlloc);
+    Block &linalgBody = linalgOp.getMapper().emplaceBlock();
+
+    // The argument-less map body simply yields the splat value.
+    rewriter.setInsertionPointToStart(&linalgBody);
+    rewriter.create<linalg::YieldOp>(loc, splatOp.getInput());
+    rewriter.replaceOp(splatOp, linalgOp.getResult()[0]);
+
+    return success();
+  }
+};
+
} // namespace
} // namespace tensor
} // namespace mlir
@@ -1110,6 +1158,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
*ctx);
RankOp::attachInterface<RankOpInterface>(*ctx);
ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
+ SplatOp::attachInterface<SplatOpInterface>(*ctx);
// Load additional dialects of which ops may get created.
ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index fe665a32d709b..b9382b9844df1 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -582,3 +582,20 @@ func.func @tensor.pad(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
// CHECK: return %[[r]] : tensor<?x?xindex>
return %0 : tensor<?x?xindex>
}
+
+// -----
+// tensor.splat bufferizes to a fresh memref.alloc filled by a linalg.map.
+// CHECK-LABEL: func @tensor.splat(
+// CHECK-SAME: %[[F:.*]]: f32)
+// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<10x2x4xf32>
+// CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: %[[MAPPED:.*]] = linalg.map
+// CHECK: outs(%[[ALLOC_T]] : tensor<10x2x4xf32>)
+// CHECK: linalg.yield %[[F]]
+// CHECK: }
+// CHECK: return %[[MAPPED]] : tensor<10x2x4xf32>
+// CHECK: }
+func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
+ %t = tensor.splat %f : tensor<10x2x4xf32>
+ return %t : tensor<10x2x4xf32>
+}
More information about the Mlir-commits mailing list