[Mlir-commits] [mlir] 884a629 - [mlir][linalg] Add scalar operands inlining pattern

Stephan Herhut llvmlistbot at llvm.org
Fri May 21 06:25:08 PDT 2021


Author: Stephan Herhut
Date: 2021-05-21T15:23:28+02:00
New Revision: 884a6291f0b995579d2cd203bfdc5b6aa427be31

URL: https://github.com/llvm/llvm-project/commit/884a6291f0b995579d2cd203bfdc5b6aa427be31
DIFF: https://github.com/llvm/llvm-project/commit/884a6291f0b995579d2cd203bfdc5b6aa427be31.diff

LOG: [mlir][linalg] Add scalar operands inlining pattern

This pattern inlines operands to a linalg.generic operation that use a constant
index and hence are loop-invariant scalars. This reduces the number of
linalg.generic operands and unlocks some canonicalizations that rely on seeing
an explicit tensor.extract.

Differential Revision: https://reviews.llvm.org/D102682

Added: 
    mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
    mlir/test/Dialect/Linalg/inline-scalar-operands.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/IR/AffineMap.h
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/IR/AffineMap.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 47dac77701143..804f9c7f34e97 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -36,6 +36,8 @@ std::unique_ptr<OperationPass<FuncOp>>
 createLinalgPromotionPass(bool dynamicBuffers, bool useAlloca);
 std::unique_ptr<OperationPass<FuncOp>> createLinalgPromotionPass();
 
+std::unique_ptr<OperationPass<FuncOp>> createLinalgInlineScalarOperandsPass();
+
 /// Create a pass to convert Linalg tiled loops to `scf.for` and `scf.parallel`
 /// loops and memref.load/memref.store accesses.
 std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgTiledLoopsToSCFPass();

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 2934f17f1e901..b14efa91e3edb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -84,6 +84,14 @@ def LinalgLowerTiledLoopsToSCF
   ];
 }
 
+def LinalgInlineScalarOperands : FunctionPass<"linalg-inline-scalar-operands"> {
+  let summary = "Inline scalar operands into linalg generic ops";
+  let constructor = "mlir::createLinalgInlineScalarOperandsPass()";
+  let dependentDialects = [
+    "linalg::LinalgDialect"
+  ];
+}
+
 def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
   let summary = "Lower the operations from the linalg dialect into affine "
                 "loops";

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 192a74495475c..4442e9067ca47 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -78,6 +78,9 @@ void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter,
 /// tensors.
 void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
 
+/// Patterns that are used to inline constant operands into linalg generic ops.
+void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);
+
 /// Options that control fusion of elementwise operations.
 struct LinalgElementwiseFusionOptions {
   /// Enable fusion of reshapes into the shape with elementwise operations. By

diff  --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index e9295650761f0..336e6188c69ee 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -135,10 +135,17 @@ class AffineMap {
   /// Returns true if this affine map is a single result constant function.
   bool isSingleConstant() const;
 
+  /// Returns true if this affine map has only constant results.
+  bool isConstant() const;
+
   /// Returns the constant result of this map. This methods asserts that the map
   /// has a single constant result.
   int64_t getSingleConstantResult() const;
 
+  /// Returns the constant results of this map. This method asserts that the map
+  /// has all constant results.
+  SmallVector<int64_t> getConstantResults() const;
+
   // Prints affine map to 'os'.
   void print(raw_ostream &os) const;
   void dump() const;

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index b073bf0079f8b..1458c94fc905d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   FusionOnTensors.cpp
   Generalization.cpp
   Hoisting.cpp
+  InlineScalarOperands.cpp
   Interchange.cpp
   Loops.cpp
   Promotion.cpp

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
new file mode 100644
index 0000000000000..aa01029471ce2
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
@@ -0,0 +1,121 @@
+//===- InlineScalarOperands.cpp - Pass to inline scalar operands ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns/pass to inline scalar operands into a generic
+// operation. A scalar operand is an input operand whose indexing map has only
+// constant results, so every iteration reads the same element of it.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Rewrites a tensor-based `linalg.generic` so that every input whose indexing
+/// map has only constant results is removed from the operand list and read
+/// through an explicit `tensor.extract` at the start of the body instead.
+struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+
+    // Partition the inputs: positions of scalar operands are recorded, all
+    // other inputs are kept together with their indexing maps.
+    SmallVector<size_t> scalarOperands;
+    SmallVector<AffineMap> newIndexingMaps;
+    SmallVector<Value> newOperands;
+    for (auto it : llvm::enumerate(llvm::zip(genericOp.getInputIndexingMaps(),
+                                             genericOp.getInputTensors()))) {
+      AffineMap map = std::get<0>(it.value());
+      if (map.isConstant()) {
+        scalarOperands.emplace_back(it.index());
+      } else {
+        newIndexingMaps.emplace_back(map);
+        newOperands.emplace_back(std::get<1>(it.value()));
+      }
+    }
+
+    // Without scalar operands there is nothing to rewrite.
+    if (scalarOperands.empty())
+      return failure();
+
+    // The outputs and their indexing maps are carried over unchanged.
+    newIndexingMaps.append(genericOp.getOutputIndexingMaps());
+
+    Location loc = genericOp->getLoc();
+    auto newOp = rewriter.create<GenericOp>(
+        loc, genericOp->getResultTypes(), newOperands,
+        genericOp.getOutputTensors(), newIndexingMaps,
+        llvm::to_vector<4>(
+            genericOp.iterator_types().template getAsValueRange<StringAttr>()));
+    rewriter.cloneRegionBefore(genericOp.region(), newOp.region(),
+                               newOp.region().begin());
+
+    Block *body = newOp.getBody();
+    PatternRewriter::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(body);
+
+    // The cloned body still has one block argument per original input. Visit
+    // the scalar operands back to front so that erasing an argument does not
+    // shift the positions of the arguments that are still to be replaced.
+    for (size_t argIdx : llvm::reverse(scalarOperands)) {
+      Value operand = genericOp.getInput(argIdx);
+      AffineMap map = genericOp.getInputIndexingMap(argIdx);
+      SmallVector<int64_t> indices = map.getConstantResults();
+      SmallVector<Value> indicesValues;
+      for (int64_t cst : indices)
+        indicesValues.emplace_back(rewriter.create<ConstantIndexOp>(loc, cst));
+      operand = rewriter.create<tensor::ExtractOp>(loc, operand, indicesValues);
+      body->getArgument(argIdx).replaceAllUsesWith(operand);
+      body->eraseArgument(argIdx);
+    }
+
+    rewriter.replaceOp(genericOp, newOp->getResults());
+    return success();
+  }
+};
+} // namespace
+
+/// Patterns that are used to inline constant operands into linalg generic
+/// ops.
+void mlir::linalg::populateInlineConstantOperandsPatterns(
+    RewritePatternSet &patterns) {
+  auto *context = patterns.getContext();
+  patterns.add<InlineScalarOperands>(context);
+}
+
+namespace {
+/// Pass that inlines scalar operands into linalg generic ops by greedily
+/// applying the InlineScalarOperands pattern to the function body.
+struct LinalgInlineScalarOperandsPass
+    : public LinalgInlineScalarOperandsBase<LinalgInlineScalarOperandsPass> {
+  void runOnFunction() override {
+    FuncOp funcOp = getFunction();
+    MLIRContext *context = funcOp.getContext();
+    RewritePatternSet patterns(context);
+
+    populateInlineConstantOperandsPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgInlineScalarOperandsPass() {
+  return std::make_unique<LinalgInlineScalarOperandsPass>();
+}

diff  --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 3c453de10fa75..2713426492dcf 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -287,11 +287,28 @@ bool AffineMap::isSingleConstant() const {
   return getNumResults() == 1 && getResult(0).isa<AffineConstantExpr>();
 }
 
+bool AffineMap::isConstant() const {
+  // Vacuously true for a map without results, e.g. `(d0) -> ()`.
+  return llvm::all_of(getResults(), [](AffineExpr expr) {
+    return expr.isa<AffineConstantExpr>();
+  });
+}
+
 int64_t AffineMap::getSingleConstantResult() const {
   assert(isSingleConstant() && "map must have a single constant result");
   return getResult(0).cast<AffineConstantExpr>().getValue();
 }
 
+SmallVector<int64_t> AffineMap::getConstantResults() const {
+  assert(isConstant() && "map must have only constant results");
+  SmallVector<int64_t> result;
+  // The number of results is known up front; size the vector once.
+  result.reserve(getNumResults());
+  for (auto expr : getResults())
+    result.emplace_back(expr.cast<AffineConstantExpr>().getValue());
+  return result;
+}
+
 unsigned AffineMap::getNumDims() const {
   assert(map && "uninitialized map storage");
   return map->numDims;

diff  --git a/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir b/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
new file mode 100644
index 0000000000000..af8fd90b93009
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s -linalg-inline-scalar-operands -split-input-file | FileCheck %s
+
+// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0)>
+#map2 = affine_map<(d0) -> (d0)>
+#map3 = affine_map<(d0) -> ()>
+
+// CHECK: func @inline_zerod(%[[ARG:.*]]: tensor<4xf32>, %[[SCALAR:.*]]: tensor<f32>)
+func @inline_zerod(%arg0: tensor<4xf32>, %scalar: tensor<f32>) -> tensor<4xf32> {
+    %0 = linalg.init_tensor [4] : tensor<4xf32>
+    // CHECK: linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]],
+    // CHECK-SAME: iterator_types = ["parallel"]} ins(%[[ARG]] : tensor<4xf32>)
+    %1 = linalg.generic {indexing_maps = [#map2, #map3, #map2],
+                         iterator_types = ["parallel"]}
+                         ins(%arg0, %scalar : tensor<4xf32>, tensor<f32>)
+                         outs(%0 : tensor<4xf32>) {
+    // CHECK: ^bb0(%{{.*}}: f32, %{{.*}}: f32)
+    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):  // no predecessors
+      // CHECK: tensor.extract %[[SCALAR]][]
+      %2 = divf %arg1, %arg2 : f32
+      linalg.yield %2 : f32
+    } -> tensor<4xf32>
+  return %1 : tensor<4xf32>
+}
+
+// -----
+
+// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0)>
+#map2 = affine_map<(d0) -> (d0)>
+#map3 = affine_map<(d0) -> (0)>
+
+// CHECK: func @inline_oned(%[[ARG:.*]]: tensor<4xf32>, %[[SCALAR:.*]]: tensor<1xf32>)
+func @inline_oned(%arg0: tensor<4xf32>, %scalar: tensor<1xf32>) -> tensor<4xf32> {
+    // CHECK: %[[ZERO:.*]] = constant 0 : index
+    %0 = linalg.init_tensor [4] : tensor<4xf32>
+    // CHECK: linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]],
+    // CHECK-SAME: iterator_types = ["parallel"]} ins(%[[ARG]] : tensor<4xf32>)
+    %1 = linalg.generic {indexing_maps = [#map2, #map3, #map2],
+                         iterator_types = ["parallel"]}
+                         ins(%arg0, %scalar : tensor<4xf32>, tensor<1xf32>)
+                         outs(%0 : tensor<4xf32>) {
+    // CHECK: ^bb0(%{{.*}}: f32, %{{.*}}: f32)
+    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):  // no predecessors
+      // CHECK: tensor.extract %[[SCALAR]][%[[ZERO]]]
+      %2 = divf %arg1, %arg2 : f32
+      linalg.yield %2 : f32
+    } -> tensor<4xf32>
+  return %1 : tensor<4xf32>
+}


        


More information about the Mlir-commits mailing list