[Mlir-commits] [mlir] 124fce0 - [mlir][linalg] Convert tensor.from_elements to destination style

Matthias Springer llvmlistbot at llvm.org
Wed Jan 25 00:18:45 PST 2023

Author: Matthias Springer
Date: 2023-01-25T09:18:38+01:00
New Revision: 124fce09a24e7ca524cb47e3f1c6047fdfca739f

URL: https://github.com/llvm/llvm-project/commit/124fce09a24e7ca524cb47e3f1c6047fdfca739f
DIFF: https://github.com/llvm/llvm-project/commit/124fce09a24e7ca524cb47e3f1c6047fdfca739f.diff

LOG: [mlir][linalg] Convert tensor.from_elements to destination style

This can be used as a pre-processing step for bufferization and allows for more efficient lowerings that do not need an alloc.

Differential Revision: https://reviews.llvm.org/D142206




diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index e2234b541fa9c..f0f7187804bbb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -24,8 +24,72 @@
 using namespace mlir;
 using namespace mlir::tensor;
+// Implements backtracking to traverse indices of the output buffer while
+// iterating over op.elements().
+//
+// Visits dimension `dim` of `shape`: for each index i in [0, shape[dim]) the
+// entry indices[dim] is set to constants[i] and the next dimension is visited
+// recursively. On the innermost dimension, one tensor.insert of *elementIt
+// into the current destination is created per index and `elementIt` is
+// advanced, so elements are consumed in row-major order.
+//
+// Preconditions: `shape` is non-empty (0-d tensors must be handled by the
+// caller), `constants` holds at least max(shape) values where constants[i]
+// is the index constant i, and `indices` has one entry per dimension (its
+// contents are used as scratch space and overwritten).
+// Returns the value produced by the last tensor.insert, or `destination`
+// unchanged when no element was inserted (some dimension has extent 0).
+static Value createInserts(RewriterBase &rewriter, Location loc, int dim,
+                           Value destination, ArrayRef<int64_t> shape,
+                           ArrayRef<Value> constants,
+                           OperandRange::iterator &elementIt,
+                           SmallVectorImpl<Value> &indices) {
+  assert(!shape.empty() && "expected a shape with at least one dimension");
+  assert(indices.size() == shape.size() && "expected one index per dimension");
+  // Innermost dimension: emit the actual inserts.
+  if (dim == static_cast<int>(shape.size()) - 1) {
+    for (int i = 0; i < shape.back(); ++i) {
+      indices.back() = constants[i];
+      destination = rewriter.create<tensor::InsertOp>(loc, *elementIt,
+                                                      destination, indices);
+      ++elementIt;
+    }
+    return destination;
+  }
+  // Outer dimension: fix this index, then recurse into the next dimension,
+  // threading the partially filled tensor through each recursive call.
+  for (int i = 0; i < shape[dim]; ++i) {
+    indices[dim] = constants[i];
+    destination = createInserts(rewriter, loc, dim + 1, destination, shape,
+                                constants, elementIt, indices);
+  }
+  return destination;
+}
+
 namespace {
+/// Lower tensor.from_elements to a sequence of chained tensor.insert ops into
+/// a tensor.empty of the result type.
+///
+/// Elements are inserted in row-major order, i.e. in the order of the op's
+/// element operands. A rank-0 result becomes a single tensor.insert with no
+/// indices; a result with zero elements is replaced by the tensor.empty
+/// directly.
+struct FromElementsOpConverter : public OpRewritePattern<FromElementsOp> {
+  using OpRewritePattern<FromElementsOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(FromElementsOp elementsOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = elementsOp.getLoc();
+    RankedTensorType tensorType = elementsOp.getType().cast<RankedTensorType>();
+    auto shape = tensorType.getShape();
+
+    // Create tensor.empty.
+    auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType, ValueRange());
+
+    // Case: tensor<elem_type>.
+    if (shape.empty()) {
+      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
+          elementsOp, elementsOp.getElements().front(), emptyOp.getResult(),
+          ValueRange());
+      return success();
+    }
+
+    // Case: tensor with zero elements (some dimension has extent 0), e.g.
+    // tensor<0xelem_type>. There is nothing to insert, so the empty tensor is
+    // already the result. Handling this here also avoids reading
+    // `constants[0]` below, which would be out of bounds when every dimension
+    // is 0 (no index constants are created in that case).
+    if (tensorType.getNumElements() == 0) {
+      rewriter.replaceOp(elementsOp, emptyOp.getResult());
+      return success();
+    }
+
+    // Create constants for the range of possible indices [0, max{shape_i}).
+    auto maxDim = *std::max_element(shape.begin(), shape.end());
+    SmallVector<Value, 2> constants;
+    constants.reserve(maxDim);
+    for (int i = 0; i < maxDim; ++i)
+      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
+
+    // Traverse all elements and create tensor.insert ops. `indices` is scratch
+    // space that createInserts overwrites dimension by dimension; constants[0]
+    // is only a placeholder initial value.
+    auto elementIt = elementsOp.getElements().begin();
+    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
+    Value result = createInserts(rewriter, loc, /*dim=*/0, emptyOp.getResult(),
+                                 shape, constants, elementIt, indices);
+
+    // Replace tensor.from_elements.
+    rewriter.replaceOp(elementsOp, result);
+    return success();
+  }
+};
+
 /// Lower tensor.generate to linalg.generic.
 struct GenerateOpConverter : public OpRewritePattern<GenerateOp> {
   using OpRewritePattern<GenerateOp>::OpRewritePattern;
@@ -172,5 +236,6 @@ struct PadOpConverter : public OpRewritePattern<PadOp> {
 void linalg::populateConvertToDestinationStylePatterns(
     RewritePatternSet &patterns) {
-  patterns.insert<GenerateOpConverter, PadOpConverter>(patterns.getContext());
+  patterns.insert<FromElementsOpConverter, GenerateOpConverter, PadOpConverter>(
+      patterns.getContext());

diff  --git a/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir b/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
index 4a29865e39cc7..3a7472fcec08f 100644
--- a/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
+++ b/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
@@ -1,5 +1,53 @@
 // RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-convert-to-destination-style-patterns -canonicalize %s | FileCheck %s
+// Rank-0 result: lowered to a single tensor.insert with an empty index list
+// into a tensor.empty.
+// CHECK-LABEL: func @tensor_from_elements_0d(
+//  CHECK-SAME:     %[[arg0:.*]]: index
+//       CHECK:   %[[empty:.*]] = tensor.empty() : tensor<index>
+//       CHECK:   %[[insert:.*]] = tensor.insert %[[arg0]] into %[[empty]][]
+//       CHECK:   return %[[insert]]
+func.func @tensor_from_elements_0d(%arg0: index) -> tensor<index> {
+  %0 = tensor.from_elements %arg0 : tensor<index>
+  return %0 : tensor<index>
+}
+
+// -----
+
+// Rank-1 result: one tensor.insert per element, chained through the
+// destination operand, using index constants 0 and 1.
+// CHECK-LABEL: func @tensor_from_elements_1d(
+//  CHECK-SAME:     %[[arg0:.*]]: index, %[[arg1:.*]]: index
+//   CHECK-DAG:   %[[empty:.*]] = tensor.empty() : tensor<2xindex>
+//   CHECK-DAG:   %[[c0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[c1:.*]] = arith.constant 1 : index
+//       CHECK:   %[[insert:.*]] = tensor.insert %[[arg0]] into %[[empty]][%[[c0]]]
+//       CHECK:   %[[insert2:.*]] = tensor.insert %[[arg1]] into %[[insert]][%[[c1]]]
+//       CHECK:   return %[[insert2]]
+func.func @tensor_from_elements_1d(%arg0: index, %arg1: index) -> tensor<2xindex> {
+  %0 = tensor.from_elements %arg0, %arg1 : tensor<2xindex>
+  return %0 : tensor<2xindex>
+}
+
+// -----
+
+// Rank-2 result: elements fill the 3x2 tensor in row-major order; the index
+// constants 0..2 are shared between both dimensions.
+// CHECK-LABEL: func @tensor_from_elements_2d(
+//  CHECK-SAME:     %[[arg0:.*]]: index, %[[arg1:.*]]: index
+//   CHECK-DAG:   %[[empty:.*]] = tensor.empty() : tensor<3x2xindex>
+//   CHECK-DAG:   %[[c0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[c1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[c2:.*]] = arith.constant 2 : index
+//       CHECK:   %[[insert0:.*]] = tensor.insert %[[arg0]] into %[[empty]][%[[c0]], %[[c0]]]
+//       CHECK:   %[[insert1:.*]] = tensor.insert %[[arg1]] into %[[insert0]][%[[c0]], %[[c1]]]
+//       CHECK:   %[[insert2:.*]] = tensor.insert %[[arg0]] into %[[insert1]][%[[c1]], %[[c0]]]
+//       CHECK:   %[[insert3:.*]] = tensor.insert %[[arg1]] into %[[insert2]][%[[c1]], %[[c1]]]
+//       CHECK:   %[[insert4:.*]] = tensor.insert %[[arg0]] into %[[insert3]][%[[c2]], %[[c0]]]
+//       CHECK:   %[[insert5:.*]] = tensor.insert %[[arg1]] into %[[insert4]][%[[c2]], %[[c1]]]
+//       CHECK:   return %[[insert5]]
+func.func @tensor_from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
+  %0 = tensor.from_elements %arg0, %arg1, %arg0, %arg1, %arg0, %arg1
+         : tensor<3x2xindex>
+  return %0 : tensor<3x2xindex>
+}
+
+// -----
+
 // CHECK: #[[$map:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func @tensor_generate(
 //  CHECK-SAME:     %[[s1:.*]]: index, %[[s2:.*]]: index


More information about the Mlir-commits mailing list