[Mlir-commits] [mlir] 884a629 - [mlir][linalg] Add scalar operands inlining pattern
Stephan Herhut
llvmlistbot at llvm.org
Fri May 21 06:25:08 PDT 2021
Author: Stephan Herhut
Date: 2021-05-21T15:23:28+02:00
New Revision: 884a6291f0b995579d2cd203bfdc5b6aa427be31
URL: https://github.com/llvm/llvm-project/commit/884a6291f0b995579d2cd203bfdc5b6aa427be31
DIFF: https://github.com/llvm/llvm-project/commit/884a6291f0b995579d2cd203bfdc5b6aa427be31.diff
LOG: [mlir][linalg] Add scalar operands inlining pattern
This pattern inlines operands of a linalg.generic operation whose indexing maps
are constant and hence are loop-invariant scalars. This reduces the number of
linalg.generic operands and unlocks some canonicalizations that rely on seeing
an explicit tensor.extract.
Differential Revision: https://reviews.llvm.org/D102682
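As a rough before/after sketch of the rewrite (adapted from the inline_zerod case in
the new test file below; #id and #scalar are shorthand here for
affine_map<(d0) -> (d0)> and affine_map<(d0) -> ()>), the 0-d scalar operand is
dropped from the operand list and re-materialized inside the body with a
tensor.extract:

  // Before: %scalar is an extra operand with a constant () indexing map.
  %1 = linalg.generic {indexing_maps = [#id, #scalar, #id],
                       iterator_types = ["parallel"]}
      ins(%arg0, %scalar : tensor<4xf32>, tensor<f32>)
      outs(%0 : tensor<4xf32>) {
  ^bb0(%a: f32, %s: f32, %out: f32):
    %2 = divf %a, %s : f32
    linalg.yield %2 : f32
  } -> tensor<4xf32>

  // After: the scalar operand is inlined as a tensor.extract in the body.
  %1 = linalg.generic {indexing_maps = [#id, #id],
                       iterator_types = ["parallel"]}
      ins(%arg0 : tensor<4xf32>)
      outs(%0 : tensor<4xf32>) {
  ^bb0(%a: f32, %out: f32):
    %s = tensor.extract %scalar[] : tensor<f32>
    %2 = divf %a, %s : f32
    linalg.yield %2 : f32
  } -> tensor<4xf32>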
Added:
mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/IR/AffineMap.h
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/IR/AffineMap.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 47dac77701143..804f9c7f34e97 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -36,6 +36,8 @@ std::unique_ptr<OperationPass<FuncOp>>
createLinalgPromotionPass(bool dynamicBuffers, bool useAlloca);
std::unique_ptr<OperationPass<FuncOp>> createLinalgPromotionPass();
+std::unique_ptr<OperationPass<FuncOp>> createLinalgInlineScalarOperandsPass();
+
/// Create a pass to convert Linalg tiled loops to `scf.for` and `scf.parallel`
/// loops and memref.load/memref.store accesses.
std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgTiledLoopsToSCFPass();
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 2934f17f1e901..b14efa91e3edb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -84,6 +84,14 @@ def LinalgLowerTiledLoopsToSCF
];
}
+def LinalgInlineScalarOperands : FunctionPass<"linalg-inline-scalar-operands"> {
+ let summary = "Inline scalar operands into linalg generic ops";
+ let constructor = "mlir::createLinalgInlineScalarOperandsPass()";
+ let dependentDialects = [
+ "linalg::LinalgDialect"
+ ];
+}
+
def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
let summary = "Lower the operations from the linalg dialect into affine "
"loops";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 192a74495475c..4442e9067ca47 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -78,6 +78,9 @@ void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter,
/// tensors.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
+/// Patterns that are used to inline constant operands into linalg generic ops.
+void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);
+
/// Options that control fusion of elementwise operations.
struct LinalgElementwiseFusionOptions {
/// Enable fusion of reshapes into the shape with elementwise operations. By
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index e9295650761f0..336e6188c69ee 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -135,10 +135,17 @@ class AffineMap {
/// Returns true if this affine map is a single result constant function.
bool isSingleConstant() const;
+ /// Returns true if this affine map has only constant results.
+ bool isConstant() const;
+
/// Returns the constant result of this map. This method asserts that the map
/// has a single constant result.
int64_t getSingleConstantResult() const;
+ /// Returns the constant results of this map. This method asserts that the map
+ /// has all constant results.
+ SmallVector<int64_t> getConstantResults() const;
+
// Prints affine map to 'os'.
void print(raw_ostream &os) const;
void dump() const;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index b073bf0079f8b..1458c94fc905d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
FusionOnTensors.cpp
Generalization.cpp
Hoisting.cpp
+ InlineScalarOperands.cpp
Interchange.cpp
Loops.cpp
Promotion.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
new file mode 100644
index 0000000000000..aa01029471ce2
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
@@ -0,0 +1,110 @@
+//===- InlineScalarOperands.cpp - Pass to inline scalar operands =============//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns/pass to inline scalar operands into a generic
+// operation. A scalar operand is an operand whose indexing map has a constant
+// rhs.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+ if (!genericOp.hasTensorSemantics())
+ return failure();
+
+ SmallVector<size_t> scalarOperands;
+ SmallVector<AffineMap> newIndexingMaps;
+ SmallVector<Value> newOperands;
+ for (auto it : llvm::enumerate(llvm::zip(genericOp.getInputIndexingMaps(),
+ genericOp.getInputTensors()))) {
+ AffineMap map = std::get<0>(it.value());
+ if (map.isConstant()) {
+ scalarOperands.emplace_back(it.index());
+ } else {
+ newIndexingMaps.emplace_back(map);
+ newOperands.emplace_back(std::get<1>(it.value()));
+ }
+ }
+
+ if (scalarOperands.empty())
+ return failure();
+
+ newIndexingMaps.append(genericOp.getOutputIndexingMaps());
+
+ Location loc = genericOp->getLoc();
+ auto newOp = rewriter.create<GenericOp>(
+ loc, genericOp->getResultTypes(), newOperands,
+ genericOp.getOutputTensors(), newIndexingMaps,
+ llvm::to_vector<4>(
+ genericOp.iterator_types().template getAsValueRange<StringAttr>()));
+ rewriter.cloneRegionBefore(genericOp.region(), newOp.region(),
+ newOp.region().begin());
+
+ Block *body = newOp.getBody();
+ PatternRewriter::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(body);
+
+ for (auto idx : llvm::reverse(scalarOperands)) {
+ Value operand = genericOp.getInput(idx);
+ AffineMap map = genericOp.getInputIndexingMap(idx);
+ SmallVector<int64_t> indices = map.getConstantResults();
+ SmallVector<Value> indicesValues;
+ for (auto idx : indices)
+ indicesValues.emplace_back(rewriter.create<ConstantIndexOp>(loc, idx));
+ operand = rewriter.create<tensor::ExtractOp>(loc, operand, indicesValues);
+ body->getArgument(idx).replaceAllUsesWith(operand);
+ body->eraseArgument(idx);
+ }
+
+ rewriter.replaceOp(genericOp, newOp->getResults());
+ return success();
+ }
+};
+} // namespace
+
+/// Patterns that are used to inline constant operands into linalg generic
+/// ops.
+void mlir::linalg::populateInlineConstantOperandsPatterns(
+ RewritePatternSet &patterns) {
+ auto *context = patterns.getContext();
+ patterns.add<InlineScalarOperands>(context);
+}
+
+namespace {
+/// Pass that inlines scalar operands into linalg generic ops.
+struct LinalgInlineScalarOperandsPass
+ : public LinalgInlineScalarOperandsBase<LinalgInlineScalarOperandsPass> {
+ void runOnFunction() override {
+ FuncOp funcOp = getFunction();
+ MLIRContext *context = funcOp.getContext();
+ RewritePatternSet patterns(context);
+
+ populateInlineConstantOperandsPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+ }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgInlineScalarOperandsPass() {
+ return std::make_unique<LinalgInlineScalarOperandsPass>();
+}
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 3c453de10fa75..2713426492dcf 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -287,11 +287,25 @@ bool AffineMap::isSingleConstant() const {
return getNumResults() == 1 && getResult(0).isa<AffineConstantExpr>();
}
+bool AffineMap::isConstant() const {
+ return llvm::all_of(getResults(), [](AffineExpr expr) {
+ return expr.isa<AffineConstantExpr>();
+ });
+}
+
int64_t AffineMap::getSingleConstantResult() const {
assert(isSingleConstant() && "map must have a single constant result");
return getResult(0).cast<AffineConstantExpr>().getValue();
}
+SmallVector<int64_t> AffineMap::getConstantResults() const {
+ assert(isConstant() && "map must have only constant results");
+ SmallVector<int64_t> result;
+ for (auto expr : getResults())
+ result.emplace_back(expr.cast<AffineConstantExpr>().getValue());
+ return result;
+}
+
unsigned AffineMap::getNumDims() const {
assert(map && "uninitialized map storage");
return map->numDims;
diff --git a/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir b/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
new file mode 100644
index 0000000000000..af8fd90b93009
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s -linalg-inline-scalar-operands -split-input-file | FileCheck %s
+
+// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0)>
+#map2 = affine_map<(d0) -> (d0)>
+#map3 = affine_map<(d0) -> ()>
+
+// CHECK: func @inline_zerod(%[[ARG:.*]]: tensor<4xf32>, %[[SCALAR:.*]]: tensor<f32>)
+func @inline_zerod(%arg0: tensor<4xf32>, %scalar: tensor<f32>) -> tensor<4xf32> {
+ %0 = linalg.init_tensor [4] : tensor<4xf32>
+ // CHECK: linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]],
+ // CHECK-SAME: iterator_types = ["parallel"]} ins(%[[ARG]] : tensor<4xf32>)
+ %1 = linalg.generic {indexing_maps = [#map2, #map3, #map2],
+ iterator_types = ["parallel"]}
+ ins(%arg0, %scalar : tensor<4xf32>, tensor<f32>)
+ outs(%0 : tensor<4xf32>) {
+ // CHECK: ^bb0(%{{.*}}: f32, %{{.*}}: f32)
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): // no predecessors
+ // CHECK: tensor.extract %[[SCALAR]][]
+ %2 = divf %arg1, %arg2 : f32
+ linalg.yield %2 : f32
+ } -> tensor<4xf32>
+ return %1 : tensor<4xf32>
+}
+
+// -----
+
+// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0)>
+#map2 = affine_map<(d0) -> (d0)>
+#map3 = affine_map<(d0) -> (0)>
+
+// CHECK: func @inline_oned(%[[ARG:.*]]: tensor<4xf32>, %[[SCALAR:.*]]: tensor<1xf32>)
+func @inline_oned(%arg0: tensor<4xf32>, %scalar: tensor<1xf32>) -> tensor<4xf32> {
+ // CHECK: %[[ZERO:.*]] = constant 0 : index
+ %0 = linalg.init_tensor [4] : tensor<4xf32>
+ // CHECK: linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]],
+ // CHECK-SAME: iterator_types = ["parallel"]} ins(%[[ARG]] : tensor<4xf32>)
+ %1 = linalg.generic {indexing_maps = [#map2, #map3, #map2],
+ iterator_types = ["parallel"]}
+ ins(%arg0, %scalar : tensor<4xf32>, tensor<1xf32>)
+ outs(%0 : tensor<4xf32>) {
+ // CHECK: ^bb0(%{{.*}}: f32, %{{.*}}: f32)
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): // no predecessors
+ // CHECK: tensor.extract %[[SCALAR]][%[[ZERO]]]
+ %2 = divf %arg1, %arg2 : f32
+ linalg.yield %2 : f32
+ } -> tensor<4xf32>
+ return %1 : tensor<4xf32>
+}