[Mlir-commits] [mlir] 45ccff1 - [mlir][linalg] Convert tensor.generate to destination style

Matthias Springer llvmlistbot at llvm.org
Tue Jan 24 00:18:07 PST 2023


Author: Matthias Springer
Date: 2023-01-24T09:13:08+01:00
New Revision: 45ccff175b38b96251f1ba065c1c22b7b366f25d

URL: https://github.com/llvm/llvm-project/commit/45ccff175b38b96251f1ba065c1c22b7b366f25d
DIFF: https://github.com/llvm/llvm-project/commit/45ccff175b38b96251f1ba065c1c22b7b366f25d.diff

LOG: [mlir][linalg] Convert tensor.generate to destination style

This can be used as a pre-processing step for bufferization and allows for more efficient lowerings without an allocation.

Differential Revision: https://reviews.llvm.org/D142205

Added: 
    mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
    mlir/test/Dialect/Linalg/convert-to-destination-style.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 286368550aaad..7e8507d3c9df2 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -55,6 +55,10 @@ void populatePadTensorTilingPatterns(RewritePatternSet &patterns,
 void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns,
                                        bool removeDeadArgsAndResults = true);
 
+/// Populate patterns that convert non-destination-style ops to destination
+/// style ops. Currently only tensor.generate is converted (to linalg.generic).
+void populateConvertToDestinationStylePatterns(RewritePatternSet &patterns);
+
 /// Populate patterns for vectorizing low-D convolution ops. This is a step in
 /// progressive lowering for convolution ops, it assume high-D convolution ops
 /// were decomposed previously.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 4ca9f617adc3e..762f2f1dad506 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   ConstantFold.cpp
+  ConvertToDestinationStyle.cpp
   DataLayoutPropagation.cpp
   DecomposeLinalgOps.cpp
   Detensorize.cpp

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
new file mode 100644
index 0000000000000..859657cbfaec0
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -0,0 +1,78 @@
+//===- ConvertToDestinationStyle.cpp - Convert non-DPS to DPS ops ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns to convert non-DPS ops to DPS ops. New
+// tensor.empty ops are inserted as destinations. Such tensor.empty ops can be
+// eliminated with "empty tensor elimination", so that the DPS ops can
+// bufferize without an allocation (assuming there are no further conflicts).
+//
+//===----------------------------------------------------------------------===//
+//
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+/// Converts tensor.generate to a destination-style linalg.generic. A new
+/// tensor.empty with the same type and dynamic extents is its `outs` operand.
+struct GenerateOpConverter : public OpRewritePattern<GenerateOp> {
+  using OpRewritePattern<GenerateOp>::OpRewritePattern;
+  /// Fails to match only if the tensor.generate body is not a single block.
+  LogicalResult matchAndRewrite(GenerateOp generateOp,
+                                PatternRewriter &rewriter) const override {
+    // Exactly one body block is required: it is merged into the generic below.
+    if (!generateOp.getBody().hasOneBlock())
+      return failure();
+    // The result type is also used for the destination and the generic.
+    Location loc = generateOp.getLoc();
+    RankedTensorType tensorType = generateOp.getType().cast<RankedTensorType>();
+    // Create a tensor.empty destination with the same dynamic extents. Empty
+    // tensor elimination can later remove it to avoid an allocation.
+    auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType,
+                                            generateOp.getDynamicExtents());
+    // Create an all-parallel linalg.generic with no inputs. Inside its body,
+    // one linalg.index per dim replaces the matching tensor.generate bbArg.
+    SmallVector<utils::IteratorType> iteratorTypes(
+        tensorType.getRank(), utils::IteratorType::parallel);
+    SmallVector<AffineMap> indexingMaps(
+        1, rewriter.getMultiDimIdentityMap(tensorType.getRank()));
+    auto genericOp = rewriter.create<linalg::GenericOp>(
+        loc, tensorType, /*inputs=*/ValueRange(),
+        /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
+        indexingMaps, iteratorTypes);
+    Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
+                                       tensorType.getElementType(), loc);
+    rewriter.setInsertionPointToStart(body);
+    SmallVector<Value> bbArgReplacements;
+    for (int64_t i = 0; i < tensorType.getRank(); ++i)
+      bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
+    rewriter.mergeBlocks(&generateOp.getBody().front(), body,
+                         bbArgReplacements);
+    // The merged body still ends with a tensor.yield; replace it with a
+    // linalg.yield of the same value.
+    auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
+    rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
+    // Replace the tensor.generate with the result of the linalg.generic, which
+    // has the same tensor type.
+    rewriter.replaceOp(generateOp, genericOp->getResult(0));
+    return success();
+  }
+};
+
+} // namespace
+/// Registers GenerateOpConverter (tensor.generate -> linalg.generic).
+void linalg::populateConvertToDestinationStylePatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<GenerateOpConverter>(patterns.getContext());
+}

diff  --git a/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir b/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
new file mode 100644
index 0000000000000..61df430780377
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-convert-to-destination-style-patterns %s | FileCheck %s
+// tensor.generate is rewritten into tensor.empty + linalg.generic (DPS form).
+// CHECK: #[[$map:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @tensor_generate(
+//  CHECK-SAME:     %[[s1:.*]]: index, %[[s2:.*]]: index
+//       CHECK:   %[[empty:.*]] = tensor.empty(%[[s1]], %[[s2]]) : tensor<?x?xindex>
+//       CHECK:   %[[generic:.*]] = linalg.generic
+//  CHECK-SAME:       {indexing_maps = [#[[$map]]], iterator_types = ["parallel", "parallel"]}
+//  CHECK-SAME:       outs(%[[empty]] : tensor<?x?xindex>) {
+//       CHECK:     %[[i0:.*]] = linalg.index 0
+//       CHECK:     %[[i1:.*]] = linalg.index 1
+//       CHECK:     %[[added:.*]] = arith.addi %[[i0]], %[[i1]]
+//       CHECK:     linalg.yield %[[added]]
+//       CHECK:   }
+//       CHECK:   return %[[generic]]
+func.func @tensor_generate(%s1: index, %s2: index) -> tensor<?x?xindex> {
+  %0 = tensor.generate %s1, %s2 {
+    ^bb0(%arg0: index, %arg1: index):
+    %1 = arith.addi %arg0, %arg1 : index
+    tensor.yield %1 : index
+  } : tensor<?x?xindex>
+  return %0 : tensor<?x?xindex>
+}

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 5ce43ff99232b..7842e860c6a67 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -128,6 +128,10 @@ struct TestLinalgTransforms
       *this, "test-erase-unnecessary-inputs",
       llvm::cl::desc("Test patterns to erase unnecessary inputs"),
       llvm::cl::init(false)};
+  Option<bool> testConvertToDestinationStylePatterns{
+      *this, "test-convert-to-destination-style-patterns",
+      llvm::cl::desc("Test patterns that convert ops to destination style"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -218,6 +222,12 @@ static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+/// Greedily applies the DPS conversion patterns to `rootOp`; ignores failure.
+static void applyConvertToDestinationStylePatterns(Operation *rootOp) {
+  RewritePatternSet dpsPatterns(rootOp->getContext());
+  populateConvertToDestinationStylePatterns(dpsPatterns);
+  (void)applyPatternsAndFoldGreedily(rootOp, std::move(dpsPatterns));
+}
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -244,6 +254,8 @@ void TestLinalgTransforms::runOnOperation() {
     return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
   if (testEraseUnnecessaryInputs)
     return applyEraseUnnecessaryInputs(getOperation());
+  if (testConvertToDestinationStylePatterns)
+    applyConvertToDestinationStylePatterns(getOperation());
 }
 
 namespace mlir {


        


More information about the Mlir-commits mailing list