[Mlir-commits] [mlir] 45ccff1 - [mlir][linalg] Convert tensor.generate to destination style
Matthias Springer
llvmlistbot at llvm.org
Tue Jan 24 00:18:07 PST 2023
Author: Matthias Springer
Date: 2023-01-24T09:13:08+01:00
New Revision: 45ccff175b38b96251f1ba065c1c22b7b366f25d
URL: https://github.com/llvm/llvm-project/commit/45ccff175b38b96251f1ba065c1c22b7b366f25d
DIFF: https://github.com/llvm/llvm-project/commit/45ccff175b38b96251f1ba065c1c22b7b366f25d.diff
LOG: [mlir][linalg] Convert tensor.generate to destination style
This can be used as a pre-processing step for bufferization and allows for more efficient lowerings without an alloc.
Differential Revision: https://reviews.llvm.org/D142205
Added:
mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 286368550aaad..7e8507d3c9df2 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -55,6 +55,10 @@ void populatePadTensorTilingPatterns(RewritePatternSet &patterns,
void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns,
bool removeDeadArgsAndResults = true);
+/// Populate patterns that convert non-destination-style ops to destination
+/// style ops. Currently this covers tensor.generate, which is rewritten into a
+/// linalg.generic that writes into a newly created tensor.empty.
+void populateConvertToDestinationStylePatterns(RewritePatternSet &patterns);
+
/// Populate patterns for vectorizing low-D convolution ops. This is a step in
/// progressive lowering for convolution ops, it assume high-D convolution ops
/// were decomposed previously.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 4ca9f617adc3e..762f2f1dad506 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
ConstantFold.cpp
+ ConvertToDestinationStyle.cpp
DataLayoutPropagation.cpp
DecomposeLinalgOps.cpp
Detensorize.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
new file mode 100644
index 0000000000000..859657cbfaec0
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -0,0 +1,78 @@
+//===- ConvertToDestinationStyle.cpp - Convert non-DPS to DPS ops ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns to convert non-DPS ops to DPS ops. New
+// tensor.empty ops are inserted as destinations. Such tensor.empty ops can be
+// eliminated with "empty tensor elimination", allowing them to bufferize
+// without an allocation (assuming there are no further conflicts).
+//
+//===----------------------------------------------------------------------===//
+//
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+
+/// Lower tensor.generate to linalg.generic.
+/// Rewrites tensor.generate into destination-passing style. A new tensor.empty
+/// (with the op's dynamic extents) serves as the destination. A linalg.generic
+/// with a single identity-indexed, all-parallel output takes over the generate
+/// body, whose index block arguments are replaced by linalg.index ops.
+struct GenerateOpConverter : public OpRewritePattern<GenerateOp> {
+  using OpRewritePattern<GenerateOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenerateOp generateOp,
+                                PatternRewriter &rewriter) const override {
+    // Only a single-block body can be spliced into the linalg.generic below.
+    Region &srcRegion = generateOp.getBody();
+    if (!srcRegion.hasOneBlock())
+      return failure();
+
+    Location loc = generateOp.getLoc();
+    RankedTensorType resultType =
+        generateOp.getType().cast<RankedTensorType>();
+    int64_t rank = resultType.getRank();
+
+    // Destination tensor: a tensor.empty sized by the same dynamic extents.
+    auto destOp =
+        rewriter.create<EmptyOp>(loc, resultType, generateOp.getDynamicExtents());
+
+    // One output operand, accessed through the identity map, with every loop
+    // marked parallel.
+    SmallVector<utils::IteratorType> loopKinds(rank,
+                                               utils::IteratorType::parallel);
+    SmallVector<AffineMap> maps(1, rewriter.getMultiDimIdentityMap(rank));
+    auto linalgOp = rewriter.create<linalg::GenericOp>(
+        loc, resultType, /*inputs=*/ValueRange(),
+        /*outputs=*/ValueRange{destOp.getResult()}, /*indexingMaps=*/maps,
+        loopKinds);
+
+    // Build the generic body (one block argument: the output element), emit
+    // one linalg.index per dimension at its start, and splice the generate
+    // body in, substituting those indices for its block arguments.
+    Block *newBody = rewriter.createBlock(&linalgOp->getRegion(0), {},
+                                          resultType.getElementType(), loc);
+    rewriter.setInsertionPointToStart(newBody);
+    SmallVector<Value> indexValues;
+    for (int64_t dim = 0; dim < rank; ++dim)
+      indexValues.push_back(rewriter.create<linalg::IndexOp>(loc, dim));
+    rewriter.mergeBlocks(&srcRegion.front(), newBody, indexValues);
+
+    // The spliced terminator is still a tensor.yield; swap it for a
+    // linalg.yield of the same value.
+    auto oldYield = cast<tensor::YieldOp>(newBody->getTerminator());
+    rewriter.replaceOpWithNewOp<linalg::YieldOp>(oldYield, oldYield.getValue());
+
+    // Redirect all uses of the tensor.generate result to the generic op.
+    rewriter.replaceOp(generateOp, linalgOp->getResult(0));
+    return success();
+  }
+};
+
+} // namespace
+
+void linalg::populateConvertToDestinationStylePatterns(
+    RewritePatternSet &patterns) {
+  // Register every non-DPS -> DPS conversion pattern defined in this file.
+  MLIRContext *context = patterns.getContext();
+  patterns.add<GenerateOpConverter>(context);
+}
diff --git a/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir b/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
new file mode 100644
index 0000000000000..61df430780377
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-convert-to-destination-style-patterns %s | FileCheck %s
+
+// tensor.generate is rewritten into a linalg.generic whose output is a new
+// tensor.empty with the same dynamic sizes; the generate body is inlined with
+// linalg.index ops standing in for its index block arguments.
+// CHECK: #[[$map:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @tensor_generate(
+// CHECK-SAME: %[[s1:.*]]: index, %[[s2:.*]]: index
+// CHECK: %[[empty:.*]] = tensor.empty(%[[s1]], %[[s2]]) : tensor<?x?xindex>
+// CHECK: %[[generic:.*]] = linalg.generic
+// CHECK-SAME: {indexing_maps = [#[[$map]]], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: outs(%[[empty]] : tensor<?x?xindex>) {
+// CHECK: %[[i0:.*]] = linalg.index 0
+// CHECK: %[[i1:.*]] = linalg.index 1
+// CHECK: %[[added:.*]] = arith.addi %[[i0]], %[[i1]]
+// CHECK: linalg.yield %[[added]]
+// CHECK: }
+// CHECK: return %[[generic]]
+func.func @tensor_generate(%s1: index, %s2: index) -> tensor<?x?xindex> {
+ %0 = tensor.generate %s1, %s2 {
+ ^bb0(%arg0: index, %arg1: index):
+ %1 = arith.addi %arg0, %arg1 : index
+ tensor.yield %1 : index
+ } : tensor<?x?xindex>
+ return %0 : tensor<?x?xindex>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 5ce43ff99232b..7842e860c6a67 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -128,6 +128,10 @@ struct TestLinalgTransforms
*this, "test-erase-unnecessary-inputs",
llvm::cl::desc("Test patterns to erase unnecessary inputs"),
llvm::cl::init(false)};
+ /// When set, runOnOperation applies the patterns registered by
+ /// populateConvertToDestinationStylePatterns to the pass's root op.
+ Option<bool> testConvertToDestinationStylePatterns{
+ *this, "test-convert-to-destination-style-patterns",
+ llvm::cl::desc("Test patterns that convert ops to destination style"),
+ llvm::cl::init(false)};
};
} // namespace
@@ -218,6 +222,12 @@ static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
+/// Greedily applies the convert-to-destination-style patterns under `rootOp`.
+static void applyConvertToDestinationStylePatterns(Operation *rootOp) {
+  MLIRContext *context = rootOp->getContext();
+  RewritePatternSet patterns(context);
+  populateConvertToDestinationStylePatterns(patterns);
+  // This test driver does not act on whether the greedy rewrite converged.
+  LogicalResult converged =
+      applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+  (void)converged;
+}
+
/// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnOperation() {
if (testPatterns)
@@ -244,6 +254,8 @@ void TestLinalgTransforms::runOnOperation() {
return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
if (testEraseUnnecessaryInputs)
return applyEraseUnnecessaryInputs(getOperation());
+ if (testConvertToDestinationStylePatterns)
+ applyConvertToDestinationStylePatterns(getOperation());
}
namespace mlir {
More information about the Mlir-commits
mailing list