[Mlir-commits] [mlir] e7328a9 - [mlir][linalg] Fold duplicate and unused inputs in linalg.generic
Matthias Springer
llvmlistbot at llvm.org
Fri Dec 9 06:22:03 PST 2022
Author: Matthias Springer
Date: 2022-12-09T15:18:26+01:00
New Revision: e7328a9eb22307d80f86f668a75c2b082ee8636e
URL: https://github.com/llvm/llvm-project/commit/e7328a9eb22307d80f86f668a75c2b082ee8636e
DIFF: https://github.com/llvm/llvm-project/commit/e7328a9eb22307d80f86f668a75c2b082ee8636e.diff
LOG: [mlir][linalg] Fold duplicate and unused inputs in linalg.generic
If an input bbArg is not used, its corresponding input operand is removed. If there are duplicate input operands or input operands that are also used as output operands, the duplicate input operands are removed. Output operands are never modified.
Differential Revision: https://reviews.llvm.org/D139709
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
mlir/test/Dialect/Linalg/erase-unused-operands-and-results.mlir
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 37143f9084145..4081eb04d8098 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -88,6 +88,10 @@ void populateDataLayoutPropagationPatterns(RewritePatternSet &patterns);
/// This is effectively DCE for a linalg op.
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);
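+/// Patterns to remove unnecessary inputs of `linalg.generic` ops: uses of an
+/// input bbArg are redirected to the bbArg of a later operand that is the same
+/// value with the same indexing map (possibly an output), and input operands
+/// that are unused or duplicated are removed. Outputs are never modified.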
+void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns);
+
/// Function type to control generic op dimension collapsing. It is expected
/// to return an array of `ReassociationIndices` representing dimensions that
/// should be merged.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
index 87df83fb928f2..defa027517584 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
@@ -56,7 +56,9 @@ namespace {
struct DeduplicateAndRemoveDeadOperandsAndResults
: public OpRewritePattern<GenericOp> {
- using OpRewritePattern<GenericOp>::OpRewritePattern;
+ DeduplicateAndRemoveDeadOperandsAndResults(MLIRContext *ctx,
+ bool removeOutputs)
+ : OpRewritePattern<GenericOp>(ctx), removeOutputs(removeOutputs) {}
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
@@ -120,6 +122,9 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
}
private:
+ /// If false, all output operands are preserved by this pattern.
+ bool removeOutputs;
+
// Deduplicate input operands, and return the
// - Mapping from operand position in the original op, to operand position in
// the canonicalized op.
@@ -176,9 +181,9 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
dedupedOutpts;
- // If the op doesnt have tensor semantics, keep all the outputs as
- // preserved.
- if (!genericOp.hasTensorSemantics()) {
+ // If the op doesn't have tensor semantics or outputs should not be removed,
+ // keep all the outputs as preserved.
+ if (!genericOp.hasTensorSemantics() || !removeOutputs) {
for (const auto &en : llvm::enumerate(genericOp.getDpsInitOperands())) {
origToNewPos[en.index()] = newOutputOperands.size();
newOutputOperands.push_back(en.value()->get());
@@ -353,10 +358,69 @@ struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
return failure();
}
};
+
+/// Fold uses of duplicate inputs in the body of a linalg.generic. E.g.:
+/// ```
+/// linalg.generic ins(%a, %b, %a, %b) outs(%a)
+/// ^bb0(%in0, %in1, %in2, %in3, %out1)
+/// ```
+/// Assuming that all occurrences of %a use the same indexing map, and likewise
+/// for all occurrences of %b:
+/// * All uses of %in0 and %in2 are replaced with %out1
+/// * All uses of %in1 are replaced with %in3
+/// This pattern can enable additional canonicalizations: In the above example,
+/// %in0, %in1 and %in2 have no uses anymore and their corresponding operands
+/// can be folded away. This pattern does not modify uses of output block args.
+///
+/// Each used input bbArg is redirected to the bbArg of the *last* operand
+/// (input or output) that is the same SSA value with the same indexing map.
+/// That last operand has no later match of its own, so a replacement target is
+/// never itself replaced: replacements do not chain, and the order in which
+/// they are applied does not matter.
+struct FoldDuplicateInputBbArgs : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    Block *body = genericOp.getBody();
+    // Compute the indexing maps once: `getIndexingMapsArray()` builds a new
+    // vector on every call, so calling it inside the nested loop below would
+    // redo that work for every compared operand pair.
+    SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+    int numOperands = genericOp->getNumOperands();
+
+    // Find replacement bbArgs for all used input bbArgs.
+    DenseMap<int, int> replacements;
+    for (int i = 0, e = genericOp.getNumDpsInputs(); i < e; ++i) {
+      // Skip bbArgs that have no uses; there is nothing to redirect.
+      if (body->getArgument(i).use_empty())
+        continue;
+      // Find replacement bbArg. This can be an input or an output bbArg.
+      // Scan from the back so the last matching operand is chosen.
+      for (int j = numOperands - 1; j > i; --j) {
+        if (genericOp->getOperand(i) == genericOp->getOperand(j) &&
+            indexingMaps[i] == indexingMaps[j]) {
+          replacements[i] = j;
+          break;
+        }
+      }
+    }
+
+    // Stop here if no replacements were found.
+    if (replacements.empty())
+      return failure();
+
+    // Rewrite the op in place: only uses inside the body are changed; the
+    // operand list itself is left for other patterns to clean up.
+    rewriter.updateRootInPlace(genericOp, [&]() {
+      for (auto [before, after] : replacements)
+        rewriter.replaceAllUsesWith(body->getArgument(before),
+                                    body->getArgument(after));
+    });
+
+    return success();
+  }
+};
+
} // namespace
void mlir::linalg::populateEraseUnusedOperandsAndResultsPatterns(
RewritePatternSet &patterns) {
- patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults,
- RemoveUnusedCycleInGenericOp>(patterns.getContext());
+ // removeOutputs=true: unlike populateEraseUnnecessaryInputsPatterns, this set
+ // is allowed to drop output operands (and their results), not only inputs.
+ patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
+ patterns.getContext(), /*removeOutputs=*/true);
+ patterns.insert<RemoveUnusedCycleInGenericOp>(patterns.getContext());
+}
+
+/// Populate `patterns` with rewrites that only touch the inputs of
+/// `linalg.generic` ops; output operands are always preserved.
+void mlir::linalg::populateEraseUnnecessaryInputsPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  // Remove dead/duplicate inputs, but keep every output operand.
+  patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
+      context, /*removeOutputs=*/false);
+  // Redirect uses of duplicate input bbArgs so that the now-unused inputs can
+  // be removed by the pattern above.
+  patterns.insert<FoldDuplicateInputBbArgs>(context);
}
diff --git a/mlir/test/Dialect/Linalg/erase-unused-operands-and-results.mlir b/mlir/test/Dialect/Linalg/erase-unused-operands-and-results.mlir
index 8bb6153fa88b0..dea3c22be0015 100644
--- a/mlir/test/Dialect/Linalg/erase-unused-operands-and-results.mlir
+++ b/mlir/test/Dialect/Linalg/erase-unused-operands-and-results.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-erase-unused-operands-and-results | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-erase-unnecessary-inputs | FileCheck %s --check-prefix=CHECK-INPUT
// CHECK-LABEL: func @remove_deadargs_generic_basic
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
@@ -493,3 +494,29 @@ func.func @drop_only_the_cycles_not_used_by_others(%arg0 : tensor<?x?x?xf32>) ->
// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP2]]]
// CHECK-SAME: outs(%[[ARG0]], %[[INIT]] :
// CHECK: return %[[GENERIC]]#0
+
+
+// -----
+
+// %b is used both as an input and as the output, with the same indexing map.
+// Uses of the input bbArg %in_2 are expected to be redirected to the output
+// bbArg %out, after which the unused %b input operand is removed; the output
+// operand %b is kept.
+// CHECK-INPUT-LABEL: func @remove_unnecessary_input(
+// CHECK-INPUT-SAME: %[[a:.*]]: tensor<?xf32>, %[[b:.*]]: tensor<?xf32>
+#map = affine_map<(d0) -> (d0)>
+func.func @remove_unnecessary_input(%a: tensor<?xf32>, %b: tensor<?xf32>)
+ -> tensor<?xf32>
+{
+ // CHECK-INPUT: %[[result:.*]] = linalg.generic {indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel"]}
+ // CHECK-INPUT-SAME: ins(%[[a]] : tensor<?xf32>) outs(%[[b]] : tensor<?xf32>) {
+ // CHECK-INPUT: ^bb0(%[[in:.*]]: f32, %[[out:.*]]: f32):
+ // CHECK-INPUT: %[[add:.*]] = arith.addf %[[in]], %[[out]]
+ // CHECK-INPUT: linalg.yield %[[add]]
+ // CHECK-INPUT: } -> tensor<?xf32>
+ // CHECK-INPUT: return %[[result]]
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
+ ins(%a, %b : tensor<?xf32>, tensor<?xf32>) outs(%b : tensor<?xf32>) {
+ ^bb0(%in: f32, %in_2: f32, %out: f32):
+ %16 = arith.addf %in, %in_2 : f32
+ linalg.yield %16 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 92ee447792ad4..892e04bbb816f 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -113,6 +113,10 @@ struct TestLinalgTransforms
*this, "test-erase-unused-operands-and-results",
llvm::cl::desc("Test patterns to erase unused operands and results"),
llvm::cl::init(false)};
+ Option<bool> testEraseUnnecessaryInputs{
+ *this, "test-erase-unnecessary-inputs",
+ llvm::cl::desc("Test patterns to erase unnecessary inputs"),
+ llvm::cl::init(false)};
};
} // namespace
@@ -185,6 +189,12 @@ static void applyEraseUnusedOperandsAndResultsPatterns(func::FuncOp funcOp) {
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
+/// Greedily apply the "erase unnecessary inputs" pattern set to `funcOp`.
+/// The result reported by the greedy driver is intentionally discarded.
+static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
+  MLIRContext *ctx = funcOp.getContext();
+  RewritePatternSet patterns(ctx);
+  populateEraseUnnecessaryInputsPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
/// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnOperation() {
if (testPatterns)
@@ -205,6 +215,8 @@ void TestLinalgTransforms::runOnOperation() {
return applySwapExtractSliceWithFillPattern(getOperation());
if (testEraseUnusedOperandsAndResults)
return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
+ if (testEraseUnnecessaryInputs)
+ return applyEraseUnnecessaryInputs(getOperation());
}
namespace mlir {
More information about the Mlir-commits
mailing list