[Mlir-commits] [mlir] d581c94 - [mlir][linalg][bufferize] Support tensor.from_elements

Matthias Springer llvmlistbot at llvm.org
Tue Jan 25 05:21:30 PST 2022


Author: Matthias Springer
Date: 2022-01-25T22:19:59+09:00
New Revision: d581c94d6bfbb336b8620ef06e4340b5ea18a23e

URL: https://github.com/llvm/llvm-project/commit/d581c94d6bfbb336b8620ef06e4340b5ea18a23e
DIFF: https://github.com/llvm/llvm-project/commit/d581c94d6bfbb336b8620ef06e4340b5ea18a23e.diff

LOG: [mlir][linalg][bufferize] Support tensor.from_elements

This is mostly a copy of the existing tensor.from_elements bufferization pattern. Once TensorInterfaceImpl.cpp is moved to the tensor dialect, that existing tensor.from_elements rewrite pattern can be deleted.

Differential Revision: https://reviews.llvm.org/D117775

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index aaa89b4ae2242..1c1226b451688 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -229,6 +229,82 @@ struct ExtractOpInterface
   }
 };
 
+// Recursively enumerates the indices of `buffer` in row-major order starting
+// at `dim`, storing *elementIt (then advancing it) at each full index.
+static void createStores(RewriterBase &rewriter, Location loc, int dim,
+                         Value buffer, ArrayRef<int64_t> shape,
+                         ArrayRef<Value> constants,
+                         OperandRange::iterator &elementIt,
+                         SmallVectorImpl<Value> &indices) {
+  if (dim == static_cast<int>(shape.size()) - 1) {
+    for (int i = 0; i < shape.back(); ++i) {
+      indices.back() = constants[i];
+      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
+      ++elementIt;
+    }
+    return;
+  }
+  for (int i = 0; i < shape[dim]; ++i) {
+    indices[dim] = constants[i];
+    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
+                 indices);
+  }
+}
+
+/// Bufferization of tensor.from_elements via a new buffer and memref.stores.
+struct FromElementsOpInterface
+    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
+                                                    tensor::FromElementsOp> {
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationState &state) const {
+    auto fromElementsOp = cast<tensor::FromElementsOp>(op);
+
+    // Allocate the result buffer (deallocMemref = options.createDeallocs).
+    Location loc = op->getLoc();
+    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
+    auto shape = tensorType.getShape();
+    MemRefType resultType =
+        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+    FailureOr<Value> maybeBuffer =
+        createAlloc(rewriter, loc, resultType, {},
+                    /*deallocMemref=*/state.getOptions().createDeallocs,
+                    state.getOptions());
+    if (failed(maybeBuffer))
+      return failure();
+    Value buffer = *maybeBuffer;
+
+    // Case: no elements (e.g. tensor<0xelem_type>); nothing to store.
+    if (fromElementsOp.elements().empty()) {
+      replaceOpWithBufferizedValues(rewriter, op, buffer);
+      return success();
+    }
+
+    // Case: rank-0 tensor<elem_type>; store the single element, no indices.
+    if (shape.empty()) {
+      rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
+                                       buffer);
+      replaceOpWithBufferizedValues(rewriter, op, buffer);
+      return success();
+    }
+
+    // Index constants for [0, max{shape_i}), shared by every dimension.
+    auto maxDim = *std::max_element(shape.begin(), shape.end());
+    SmallVector<Value, 2> constants;
+    constants.reserve(maxDim);
+    for (int i = 0; i < maxDim; ++i)
+      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
+
+    // Store all elements; `indices` starts all-zero and is updated in place.
+    auto elementIt = fromElementsOp.elements().begin();
+    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
+    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
+                 indices);
+
+    replaceOpWithBufferizedValues(rewriter, op, buffer);
+    return success();
+  }
+};
+
 /// Bufferization of tensor.generate.
 struct GenerateOpInterface
     : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
@@ -562,6 +638,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
   registry.addOpInterface<DimOp, DimOpInterface>();
   registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
   registry.addOpInterface<ExtractOp, ExtractOpInterface>();
+  registry.addOpInterface<FromElementsOp, FromElementsOpInterface>();
   registry.addOpInterface<GenerateOp, GenerateOpInterface>();
   registry.addOpInterface<InsertOp, InsertOpInterface>();
   registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index eacb2bb9314dc..c4ea9a48b8ece 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1379,3 +1379,24 @@ func @tensor_generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
   // CHECK: }
   return %result : tensor<16x?xindex>
 }
+
+// -----
+
+// CHECK-LABEL: func @tensor_from_elements_2d(
+//  CHECK-SAME:     %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index
+func @tensor_from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK-DAG: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex>
+  //     CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]]
+  //     CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]]
+  //     CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]]
+  //     CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]], %[[C1]]]
+  //     CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C2]], %[[C0]]]
+  //     CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C2]], %[[C1]]]
+  %0 = tensor.from_elements %arg0, %arg1, %arg0, %arg1, %arg0, %arg1
+         : tensor<3x2xindex>  // Elements are consumed in row-major order.
+  //     CHECK: return %[[MEMREF]]
+  return %0 : tensor<3x2xindex>
+}


        


More information about the Mlir-commits mailing list