[Mlir-commits] [mlir] 71bbb78 - [mlir][linalg][bufferize] Support tensor.generate
Matthias Springer
llvmlistbot at llvm.org
Tue Jan 25 05:19:37 PST 2022
Author: Matthias Springer
Date: 2022-01-25T22:19:22+09:00
New Revision: 71bbb78b8fdc72732e3c21ee6d37f3c3868a7fdc
URL: https://github.com/llvm/llvm-project/commit/71bbb78b8fdc72732e3c21ee6d37f3c3868a7fdc
DIFF: https://github.com/llvm/llvm-project/commit/71bbb78b8fdc72732e3c21ee6d37f3c3868a7fdc.diff
LOG: [mlir][linalg][bufferize] Support tensor.generate
This is mostly a copy of the existing tensor.generate bufferization. Once TensorInterfaceImpl.cpp is moved to the tensor dialect, the existing rewrite pattern can be deleted.
Differential Revision: https://reviews.llvm.org/D117770
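For reference, a schematic (not part of this commit) of the transformation the new external model performs on a tensor.generate with one static and one dynamic extent; SSA names are illustrative and the actual output is checked by the test added below:

    // Input: element-wise generator over a 16x? tensor, %d is the dynamic extent.
    %gen = tensor.generate %d {
    ^bb0(%i: index, %j: index):
      %sum = arith.addi %i, %j : index
      tensor.yield %sum : index
    } : tensor<16x?xindex>

bufferizes to roughly:

    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    // Buffer sized by the static dim (16) and the dynamic extent %d.
    %alloc = memref.alloc(%d) : memref<16x?xindex>
    // The generator body is inlined into an scf.parallel; the tensor.yield
    // becomes a memref.store into the allocated buffer.
    scf.parallel (%i, %j) = (%c0, %c0) to (%c16, %d) step (%c1, %c1) {
      %sum = arith.addi %i, %j : index
      memref.store %sum, %alloc[%i, %j] : memref<16x?xindex>
      scf.yield
    }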
Added:
Modified:
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index ea9d885736f90..aaa89b4ae2242 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
@@ -228,6 +229,65 @@ struct ExtractOpInterface
}
};
+/// Bufferization of tensor.generate.
+struct GenerateOpInterface
+ : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
+ tensor::GenerateOp> {
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const BufferizationState &state) const {
+ auto generateOp = cast<tensor::GenerateOp>(op);
+
+ // Allocate memory.
+ Location loc = op->getLoc();
+ MemRefType memrefType =
+ getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
+ FailureOr<Value> maybeResult =
+ createAlloc(rewriter, loc, memrefType, generateOp.dynamicExtents(),
+ /*deallocMemref=*/state.getOptions().createDeallocs,
+ state.getOptions());
+ if (failed(maybeResult))
+ return failure();
+ Value result = *maybeResult;
+
+ // Collect loop bounds.
+ int64_t rank = memrefType.getRank();
+ Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ SmallVector<Value, 4> lowerBounds(rank, zero);
+ SmallVector<Value, 4> steps(rank, one);
+ SmallVector<Value, 4> upperBounds;
+ int nextDynamicIndex = 0;
+ for (int i = 0; i < rank; i++) {
+ Value upperBound = memrefType.isDynamicDim(i)
+ ? generateOp.dynamicExtents()[nextDynamicIndex++]
+ : rewriter.create<arith::ConstantIndexOp>(
+ loc, memrefType.getDimSize(i));
+ upperBounds.push_back(upperBound);
+ }
+
+ // Generate tensor elements with a parallel loop that stores into
+ // each element of the resulting memref. We use mergeBlockBefore to "move"
+ // this op's body into the scf.parallel's body.
+ auto parallel =
+ rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
+ Block *parallelBody = parallel.getBody();
+ rewriter.mergeBlockBefore(generateOp.getBody(),
+ parallelBody->getTerminator(),
+ parallelBody->getArguments());
+ // Replace the inlined yield op with a store op. The scf.parallel's builder
+ // already populated an scf.yield at the end, so we don't need to worry
+ // about creating that.
+ Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
+ rewriter.setInsertionPointAfter(elementYield);
+ rewriter.replaceOpWithNewOp<memref::StoreOp>(
+ elementYield, elementYield->getOperands()[0], result,
+ parallelBody->getArguments());
+
+ replaceOpWithBufferizedValues(rewriter, op, result);
+ return success();
+ }
+};
+
/// Bufferization of tensor.insert. Replace with memref.store.
struct InsertOpInterface
: public BufferizableOpInterface::ExternalModel<InsertOpInterface,
@@ -502,6 +562,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
registry.addOpInterface<DimOp, DimOpInterface>();
registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
registry.addOpInterface<ExtractOp, ExtractOpInterface>();
+ registry.addOpInterface<GenerateOp, GenerateOpInterface>();
registry.addOpInterface<InsertOp, InsertOpInterface>();
registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
registry.addOpInterface<RankOp, RankOpInterface>();
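As context for the registration above, a minimal sketch of how a client might attach these external models (including the new GenerateOpInterface) to a context before running bufferization; the helper name is hypothetical and the include paths may differ between MLIR versions:

    #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
    #include "mlir/IR/DialectRegistry.h"
    #include "mlir/IR/MLIRContext.h"

    // Hypothetical helper: make the tensor dialect's BufferizableOpInterface
    // external models visible to a context before bufferization passes run.
    static void attachTensorBufferizationModels(mlir::MLIRContext &context) {
      mlir::DialectRegistry registry;
      // Registers the Dim, ExtractSlice, Extract, Generate, Insert,
      // InsertSlice, and Rank op interface implementations listed above.
      mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
      context.appendDialectRegistry(registry);
    }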
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 28ee8bea2e9ec..eacb2bb9314dc 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1359,3 +1359,23 @@ func @tensor_rank(%arg0: tensor<*xf32>) -> index {
// CHECK: return %[[r]] : index
return %0 : index
}
+
+// -----
+
+// CHECK-LABEL: func @tensor_generate_static_and_dynamic(
+// CHECK-SAME: %[[arg0:.*]]: index
+func @tensor_generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
+ // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+ // CHECK: %[[alloc:.*]] = memref.alloc(%[[arg0]]) {{.*}} : memref<16x?xindex>
+ // CHECK: scf.parallel (%[[arg1:.*]], %[[arg2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c16]], %[[arg0]]) {{.*}} {
+ %result = tensor.generate %arg0 {
+ ^bb0(%i: index, %j: index):
+ %sum = arith.addi %i, %j : index
+ // CHECK: memref.store {{.*}}, %[[alloc]][%[[arg1]], %[[arg2]]]
+ // CHECK: scf.yield
+ tensor.yield %sum : index
+ } : tensor<16x?xindex>
+ // CHECK: }
+ return %result : tensor<16x?xindex>
+}