[Mlir-commits] [mlir] d581c94 - [mlir][linalg][bufferize] Support tensor.from_elements
Matthias Springer
llvmlistbot at llvm.org
Tue Jan 25 05:21:30 PST 2022
Author: Matthias Springer
Date: 2022-01-25T22:19:59+09:00
New Revision: d581c94d6bfbb336b8620ef06e4340b5ea18a23e
URL: https://github.com/llvm/llvm-project/commit/d581c94d6bfbb336b8620ef06e4340b5ea18a23e
DIFF: https://github.com/llvm/llvm-project/commit/d581c94d6bfbb336b8620ef06e4340b5ea18a23e.diff
LOG: [mlir][linalg][bufferize] Support tensor.from_elements
This is mostly a copy of the existing tensor.from_elements bufferization. Once TensorInterfaceImpl.cpp is moved to the tensor dialect, the existing rewrite pattern can be deleted.
Differential Revision: https://reviews.llvm.org/D117775
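
For illustration, the new pattern lowers a tensor.from_elements op to an allocation followed by one memref.store per element. A rough sketch for the 3x2 case exercised by the new test below (value names are illustrative; the exact output also depends on the bufferization options, e.g. alignment and dealloc insertion):

    %0 = tensor.from_elements %a, %b, %a, %b, %a, %b : tensor<3x2xindex>

    // becomes, roughly:
    %m  = memref.alloc() : memref<3x2xindex>
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    memref.store %a, %m[%c0, %c0] : memref<3x2xindex>
    memref.store %b, %m[%c0, %c1] : memref<3x2xindex>
    memref.store %a, %m[%c1, %c0] : memref<3x2xindex>
    memref.store %b, %m[%c1, %c1] : memref<3x2xindex>
    memref.store %a, %m[%c2, %c0] : memref<3x2xindex>
    memref.store %b, %m[%c2, %c1] : memref<3x2xindex>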
Added:
Modified:
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index aaa89b4ae2242..1c1226b451688 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -229,6 +229,82 @@ struct ExtractOpInterface
}
};
+// Implements backtracking to traverse indices of the output buffer while
+// iterating over op.elements().
+static void createStores(RewriterBase &rewriter, Location loc, int dim,
+ Value buffer, ArrayRef<int64_t> shape,
+ ArrayRef<Value> constants,
+ OperandRange::iterator &elementIt,
+ SmallVectorImpl<Value> &indices) {
+ if (dim == static_cast<int>(shape.size()) - 1) {
+ for (int i = 0; i < shape.back(); ++i) {
+ indices.back() = constants[i];
+ rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
+ ++elementIt;
+ }
+ return;
+ }
+ for (int i = 0; i < shape[dim]; ++i) {
+ indices[dim] = constants[i];
+ createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
+ indices);
+ }
+}
+
+/// Bufferization of tensor.from_elements.
+struct FromElementsOpInterface
+ : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
+ tensor::FromElementsOp> {
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const BufferizationState &state) const {
+ auto fromElementsOp = cast<tensor::FromElementsOp>(op);
+
+ // Allocate a buffer for the result.
+ Location loc = op->getLoc();
+ auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
+ auto shape = tensorType.getShape();
+ MemRefType resultType =
+ MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+ FailureOr<Value> maybeBuffer =
+ createAlloc(rewriter, loc, resultType, {},
+ /*deallocMemref=*/state.getOptions().createDeallocs,
+ state.getOptions());
+ if (failed(maybeBuffer))
+ return failure();
+ Value buffer = *maybeBuffer;
+
+ // Case: tensor<0xelem_type>.
+ if (fromElementsOp.elements().empty()) {
+ replaceOpWithBufferizedValues(rewriter, op, buffer);
+ return success();
+ }
+
+ // Case: tensor<elem_type>.
+ if (shape.empty()) {
+ rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
+ buffer);
+ replaceOpWithBufferizedValues(rewriter, op, buffer);
+ return success();
+ }
+
+ // Create constants for the range of possible indices [0, max{shape_i}).
+ auto maxDim = *std::max_element(shape.begin(), shape.end());
+ SmallVector<Value, 2> constants;
+ constants.reserve(maxDim);
+ for (int i = 0; i < maxDim; ++i)
+ constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
+
+ // Traverse all `elements` and create `memref.store` ops.
+ auto elementIt = fromElementsOp.elements().begin();
+ SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
+ createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
+ indices);
+
+ replaceOpWithBufferizedValues(rewriter, op, buffer);
+ return success();
+ }
+};
+
/// Bufferization of tensor.generate.
struct GenerateOpInterface
: public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
@@ -562,6 +638,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
registry.addOpInterface<DimOp, DimOpInterface>();
registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
registry.addOpInterface<ExtractOp, ExtractOpInterface>();
+ registry.addOpInterface<FromElementsOp, FromElementsOpInterface>();
registry.addOpInterface<GenerateOp, GenerateOpInterface>();
registry.addOpInterface<InsertOp, InsertOpInterface>();
registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index eacb2bb9314dc..c4ea9a48b8ece 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1379,3 +1379,24 @@ func @tensor_generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
// CHECK: }
return %result : tensor<16x?xindex>
}
+
+// -----
+
+// CHECK-LABEL: func @tensor_from_elements_2d(
+// CHECK-SAME: %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index
+func @tensor_from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK-DAG: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex>
+ // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]]
+ // CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]]
+ // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]]
+ // CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]], %[[C1]]]
+ // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C2]], %[[C0]]]
+ // CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C2]], %[[C1]]]
+ %0 = tensor.from_elements %arg0, %arg1, %arg0, %arg1, %arg0, %arg1
+ : tensor<3x2xindex>
+ // CHECK: return %[[MEMREF]]
+ return %0 : tensor<3x2xindex>
+}
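
The createStores helper in the patch above builds those stores recursively: it descends over all but the innermost dimension and, at the innermost level, emits one memref.store per element, so the operands of from_elements are consumed in row-major order. A standalone C++ sketch of just that index traversal (illustrative only, not part of the patch):

    #include <cstdio>
    #include <vector>

    // Enumerate buffer indices the same way createStores does: recurse over
    // all but the innermost dimension, then handle the innermost dimension
    // with a loop. One element would be stored per visited index.
    static void visit(int dim, const std::vector<int> &shape,
                      std::vector<int> &indices) {
      if (dim == static_cast<int>(shape.size()) - 1) {
        for (int i = 0; i < shape.back(); ++i) {
          indices.back() = i;
          // In the real pattern, a memref.store of the next element is
          // created here; this sketch just prints the index tuple.
          for (size_t d = 0; d < indices.size(); ++d)
            std::printf("%d%c", indices[d],
                        d + 1 == indices.size() ? '\n' : ',');
        }
        return;
      }
      for (int i = 0; i < shape[dim]; ++i) {
        indices[dim] = i;
        visit(dim + 1, shape, indices);
      }
    }

    int main() {
      std::vector<int> shape = {3, 2};  // Shape from the new test case.
      std::vector<int> indices(shape.size(), 0);
      visit(/*dim=*/0, shape, indices); // Prints 0,0 0,1 1,0 1,1 2,0 2,1.
    }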