[Mlir-commits] [mlir] 481b254 - [mlir][tensor][bufferize] Bufferize tensor.splat op

Matthias Springer llvmlistbot at llvm.org
Mon May 22 05:31:52 PDT 2023

Author: Matthias Springer
Date: 2023-05-22T14:31:39+02:00
New Revision: 481b254e458bc195af16fef9625cf856ef87fced

URL: https://github.com/llvm/llvm-project/commit/481b254e458bc195af16fef9625cf856ef87fced
DIFF: https://github.com/llvm/llvm-project/commit/481b254e458bc195af16fef9625cf856ef87fced.diff

LOG: [mlir][tensor][bufferize] Bufferize tensor.splat op

The op bufferizes similarly to tensor.generate: it is lowered to a linalg.map, which may then lower to a loop nest that fills the buffer.

Differential Revision: https://reviews.llvm.org/D150952




diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 9253bc2ffeb7e..935a1b9fededf 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1087,6 +1087,54 @@ struct ParallelInsertSliceOpInterface
+/// Bufferization of tensor.splat. Bufferizes to a new allocation that is filled
+/// with a linalg.map. Similar to tensor.generate.
+struct SplatOpInterface
+    : public BufferizableOpInterface::ExternalModel<SplatOpInterface,
+                                                    tensor::SplatOp> {
+
+  bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
+    return true;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto splatOp = cast<tensor::SplatOp>(op);
+
+    // Should the buffer be deallocated?
+    bool dealloc =
+        shouldDeallocateOpResult(cast<OpResult>(splatOp.getResult()), options);
+
+    // TODO: Implement memory space for this op.
+    if (options.defaultMemorySpace != Attribute())
+      return op->emitError("memory space not implemented yet");
+
+    // Allocate memory (no copy: every element is overwritten by the map).
+    Location loc = op->getLoc();
+    FailureOr<Value> tensorAlloc =
+        allocateTensorForShapedValue(rewriter, loc, splatOp.getResult(),
+                                     /*escape=*/!dealloc, options,
+                                     /*copy=*/false);
+    if (failed(tensorAlloc))
+      return failure();
+
+    // Create a linalg::MapOp without inputs that writes into the allocation.
+    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
+    auto linalgOp =
+        rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
+                                       /*init=*/*tensorAlloc);
+    Block &linalgBody = linalgOp.getMapper().emplaceBlock();
+
+    // The mapper has no inputs, so its body only yields the splat value.
+    rewriter.setInsertionPointToStart(&linalgBody);
+    rewriter.create<linalg::YieldOp>(loc, splatOp.getInput());
+    rewriter.replaceOp(splatOp, linalgOp.getResult()[0]);
+
+    return success();
+  }
+};
+
 } // namespace
 } // namespace tensor
 } // namespace mlir
@@ -1110,6 +1158,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
+    SplatOp::attachInterface<SplatOpInterface>(*ctx);
     // Load additional dialects of which ops may get created.
     ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();

diff  --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index fe665a32d709b..b9382b9844df1 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -582,3 +582,20 @@ func.func @tensor.pad(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
   // CHECK:     return %[[r]] : tensor<?x?xindex>
   return %0 : tensor<?x?xindex>
+// -----
+// CHECK-LABEL:   func @tensor.splat(
+// CHECK-SAME:        %[[F:.*]]: f32)
+// CHECK-DAG:       %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<10x2x4xf32>
+// CHECK:           %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK:           %[[MAPPED:.*]] = linalg.map
+// CHECK:                 outs(%[[ALLOC_T]] : tensor<10x2x4xf32>)
+// CHECK:             linalg.yield %[[F]]
+// CHECK:           }
+// CHECK:           return %[[MAPPED]] : tensor<10x2x4xf32>
+// CHECK:         }
+func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
+  %t = tensor.splat %f : tensor<10x2x4xf32>
+  return %t : tensor<10x2x4xf32>


More information about the Mlir-commits mailing list