[Mlir-commits] [mlir] 124fce0 - [mlir][linalg] Convert tensor.from_elements to destination style
Author: Matthias Springer
Date: 2023-01-25T09:18:38+01:00
New Revision: 124fce09a24e7ca524cb47e3f1c6047fdfca739f
URL: https://github.com/llvm/llvm-project/commit/124fce09a24e7ca524cb47e3f1c6047fdfca739f
DIFF: https://github.com/llvm/llvm-project/commit/124fce09a24e7ca524cb47e3f1c6047fdfca739f.diff
LOG: [mlir][linalg] Convert tensor.from_elements to destination style
This can serve as a pre-processing step for bufferization and enables more efficient lowerings that avoid a new allocation.
Differential Revision: https://reviews.llvm.org/D142206
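For illustration, a hand-written sketch of the rewrite on the 1-d case from the new test file
(value names are made up; the exact FileCheck expectations are in the diff below). Input:

    func.func @from_elements_1d(%a: index, %b: index) -> tensor<2xindex> {
      %0 = tensor.from_elements %a, %b : tensor<2xindex>
      return %0 : tensor<2xindex>
    }

Expected destination-style form after the pattern (tensor.empty plus chained tensor.insert):

    func.func @from_elements_1d(%a: index, %b: index) -> tensor<2xindex> {
      %c0 = arith.constant 0 : index
      %c1 = arith.constant 1 : index
      %empty = tensor.empty() : tensor<2xindex>
      %ins0 = tensor.insert %a into %empty[%c0] : tensor<2xindex>
      %ins1 = tensor.insert %b into %ins0[%c1] : tensor<2xindex>
      return %ins1 : tensor<2xindex>
    }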
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index e2234b541fa9c..f0f7187804bbb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -24,8 +24,72 @@
using namespace mlir;
using namespace mlir::tensor;
+// Implements backtracking to traverse indices of the output buffer while
+// iterating over op.elements().
+static Value createInserts(RewriterBase &rewriter, Location loc, int dim,
+ Value destination, ArrayRef<int64_t> shape,
+ ArrayRef<Value> constants,
+ OperandRange::iterator &elementIt,
+ SmallVectorImpl<Value> &indices) {
+ if (dim == static_cast<int>(shape.size()) - 1) {
+ for (int i = 0; i < shape.back(); ++i) {
+ indices.back() = constants[i];
+ destination = rewriter.create<tensor::InsertOp>(loc, *elementIt,
+ destination, indices);
+ ++elementIt;
+ }
+ return destination;
+ }
+ for (int i = 0; i < shape[dim]; ++i) {
+ indices[dim] = constants[i];
+ destination = createInserts(rewriter, loc, dim + 1, destination, shape,
+ constants, elementIt, indices);
+ }
+ return destination;
+}
+
namespace {
+/// Lower tensor.from_elements to a sequence of chained tensor.insert.
+struct FromElementsOpConverter : public OpRewritePattern<FromElementsOp> {
+ using OpRewritePattern<FromElementsOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(FromElementsOp elementsOp,
+ PatternRewriter &rewriter) const override {
+ Location loc = elementsOp.getLoc();
+ RankedTensorType tensorType = elementsOp.getType().cast<RankedTensorType>();
+ auto shape = tensorType.getShape();
+
+ // Create tensor.empty.
+ auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType, ValueRange());
+
+ // Case: tensor<elem_type>.
+ if (shape.empty()) {
+ rewriter.replaceOpWithNewOp<tensor::InsertOp>(
+ elementsOp, elementsOp.getElements().front(), emptyOp.getResult(),
+ ValueRange());
+ return success();
+ }
+
+ // Create constants for the range of possible indices [0, max{shape_i}).
+ auto maxDim = *std::max_element(shape.begin(), shape.end());
+ SmallVector<Value, 2> constants;
+ constants.reserve(maxDim);
+ for (int i = 0; i < maxDim; ++i)
+ constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
+
+ // Traverse all elements and create tensor.insert ops.
+ auto elementIt = elementsOp.getElements().begin();
+ SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
+ Value result = createInserts(rewriter, loc, /*dim=*/0, emptyOp.getResult(),
+ shape, constants, elementIt, indices);
+
+ // Replace tensor.from_elements.
+ rewriter.replaceOp(elementsOp, result);
+ return success();
+ }
+};
+
/// Lower tensor.generate to linalg.generic.
struct GenerateOpConverter : public OpRewritePattern<GenerateOp> {
using OpRewritePattern<GenerateOp>::OpRewritePattern;
@@ -172,5 +236,6 @@ struct PadOpConverter : public OpRewritePattern<PadOp> {
void linalg::populateConvertToDestinationStylePatterns(
RewritePatternSet &patterns) {
- patterns.insert<GenerateOpConverter, PadOpConverter>(patterns.getContext());
+ patterns.insert<FromElementsOpConverter, GenerateOpConverter, PadOpConverter>(
+ patterns.getContext());
}
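Why this helps bufferization (an editorial sketch, not part of the patch): once tensor.from_elements
is expressed as tensor.empty plus chained tensor.insert ops, One-Shot Bufferize can perform the
inserts in place, and if the tensor.empty is replaced by an existing destination (e.g., via empty
tensor elimination), no new allocation is needed. Assuming a pre-existing destination buffer %dest
(hypothetical, not output produced by this patch), the 1-d example could roughly bufferize to:

    // %dest is assumed to be the buffer of the eventual destination.
    func.func @from_elements_1d_bufferized(%a: index, %b: index, %dest: memref<2xindex>) {
      %c0 = arith.constant 0 : index
      %c1 = arith.constant 1 : index
      memref.store %a, %dest[%c0] : memref<2xindex>
      memref.store %b, %dest[%c1] : memref<2xindex>
      return
    }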
diff --git a/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir b/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
index 4a29865e39cc7..3a7472fcec08f 100644
--- a/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
+++ b/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
@@ -1,5 +1,53 @@
// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-convert-to-destination-style-patterns -canonicalize %s | FileCheck %s
+// CHECK-LABEL: func @tensor_from_elements_0d(
+// CHECK-SAME: %[[arg0:.*]]: index
+// CHECK: %[[empty:.*]] = tensor.empty() : tensor<index>
+// CHECK: %[[insert:.*]] = tensor.insert %[[arg0]] into %[[empty]][]
+// CHECK: return %[[insert]]
+func.func @tensor_from_elements_0d(%arg0: index) -> tensor<index> {
+ %0 = tensor.from_elements %arg0 : tensor<index>
+ return %0 : tensor<index>
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_from_elements_1d(
+// CHECK-SAME: %[[arg0:.*]]: index, %[[arg1:.*]]: index
+// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<2xindex>
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK: %[[insert:.*]] = tensor.insert %[[arg0]] into %[[empty]][%[[c0]]]
+// CHECK: %[[insert2:.*]] = tensor.insert %[[arg1]] into %[[insert]][%[[c1]]]
+// CHECK: return %[[insert2]]
+func.func @tensor_from_elements_1d(%arg0: index, %arg1: index) -> tensor<2xindex> {
+ %0 = tensor.from_elements %arg0, %arg1 : tensor<2xindex>
+ return %0 : tensor<2xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_from_elements_2d(
+// CHECK-SAME: %[[arg0:.*]]: index, %[[arg1:.*]]: index
+// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<3x2xindex>
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
+// CHECK: %[[insert0:.*]] = tensor.insert %[[arg0]] into %[[empty]][%[[c0]], %[[c0]]]
+// CHECK: %[[insert1:.*]] = tensor.insert %[[arg1]] into %[[insert0]][%[[c0]], %[[c1]]]
+// CHECK: %[[insert2:.*]] = tensor.insert %[[arg0]] into %[[insert1]][%[[c1]], %[[c0]]]
+// CHECK: %[[insert3:.*]] = tensor.insert %[[arg1]] into %[[insert2]][%[[c1]], %[[c1]]]
+// CHECK: %[[insert4:.*]] = tensor.insert %[[arg0]] into %[[insert3]][%[[c2]], %[[c0]]]
+// CHECK: %[[insert5:.*]] = tensor.insert %[[arg1]] into %[[insert4]][%[[c2]], %[[c1]]]
+// CHECK: return %[[insert5]]
+func.func @tensor_from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
+ %0 = tensor.from_elements %arg0, %arg1, %arg0, %arg1, %arg0, %arg1
+ : tensor<3x2xindex>
+ return %0 : tensor<3x2xindex>
+}
+
+// -----
+
// CHECK: #[[$map:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @tensor_generate(
// CHECK-SAME: %[[s1:.*]]: index, %[[s2:.*]]: index