[Mlir-commits] [mlir] 69011a2 - [mlir][Linalg] Make Elementwise op fusion return a map from existing values to values in the fused op.
Mahesh Ravishankar
llvmlistbot at llvm.org
Tue Jan 31 12:34:03 PST 2023
Author: Mahesh Ravishankar
Date: 2023-01-31T20:33:44Z
New Revision: 69011a2ad0ce8662f69d3abef12280b1f463f99c
URL: https://github.com/llvm/llvm-project/commit/69011a2ad0ce8662f69d3abef12280b1f463f99c
DIFF: https://github.com/llvm/llvm-project/commit/69011a2ad0ce8662f69d3abef12280b1f463f99c.diff
LOG: [mlir][Linalg] Make Elementwise op fusion return a map from existing values to values in the fused op.
The returned replacement map can be used to replace all uses of the
producer/consumer in cases where the producer/consumer has other uses
outside of the producer/consumer pair. This makes the
producer/consumer dead.
Add test and minor fixup to the test harness.
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D142848
Added:
mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 7fe9dfe2fbf5f..953eb59b95134 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -162,8 +162,12 @@ bool areElementwiseOpsFusable(OpOperand *fusedOperand);
/// Fuse two `linalg.generic` operations that have a producer-consumer
/// relationship captured through `fusedOperand`. The method expects
/// that `areElementwiseOpsFusable` returns true for the given `fusedOperand`.
-FailureOr<Operation *> fuseElementwiseOps(RewriterBase &rewriter,
- OpOperand *fusedOperand);
+/// Result of fusing two elementwise ops: the fused operation together with a
+/// map from results of the original producer/consumer to the results of the
+/// fused op that can stand in for them.
+struct ElementwiseOpFusionResult {
+ /// The operation created by the fusion.
+ Operation *fusedOp;
+ /// Maps an original producer/consumer result to its replacement in
+ /// `fusedOp`. Producer results have an entry only when they are preserved
+ /// as results of the fused op.
+ llvm::DenseMap<Value, Value> replacements;
+};
+/// Fuses the producer and consumer related through `fusedOperand`. On success
+/// returns the fused op and the replacement map; callers are responsible for
+/// applying the replacements to the uses they want rewritten.
+FailureOr<ElementwiseOpFusionResult>
+fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
/// Split the given `op` into two parts along the given iteration space
/// `dimension` at the specified `splitPoint`, and return the two parts.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 2504a2ab0c9bd..8df324dfa3818 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -23,8 +23,8 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include <utility>
#include <optional>
+#include <utility>
namespace mlir {
#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMS
@@ -73,6 +73,9 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
/// Conditions for elementwise fusion of generic operations.
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
+ if (!fusedOperand)
+ return false;
+
auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
@@ -270,7 +273,7 @@ static void generateFusedElementwiseOpRegion(
"Ill-formed GenericOp region");
}
-FailureOr<Operation *>
+FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
OpOperand *fusedOperand) {
assert(areElementwiseOpsFusable(fusedOperand) &&
@@ -390,7 +393,15 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
generateFusedElementwiseOpRegion(
rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
consumer.getNumLoops(), preservedProducerResults);
- return fusedOp.getOperation();
+ // Record, for each original result that survives in the fused op, the fused
+ // result that replaces it. The indexing below relies on the fused op listing
+ // the preserved producer results first (in producer order), followed by all
+ // consumer results.
+ ElementwiseOpFusionResult result;
+ result.fusedOp = fusedOp;
+ int resultNum = 0;
+ for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
+ if (preservedProducerResults.count(index))
+ result.replacements[producerResult] = fusedOp->getResult(resultNum++);
+ for (auto consumerResult : consumer->getResults())
+ result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
+ return result;
}
namespace {
@@ -411,13 +422,20 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
if (!controlFn(&opOperand))
continue;
- FailureOr<Operation *> fusedOp = fuseElementwiseOps(rewriter, &opOperand);
- if (succeeded(fusedOp)) {
- auto replacements =
- (*fusedOp)->getResults().take_back(genericOp.getNumResults());
- rewriter.replaceOp(genericOp, replacements);
- return success();
+ FailureOr<ElementwiseOpFusionResult> fusionResult =
+ fuseElementwiseOps(rewriter, &opOperand);
+ // Bail out on failure: `fusionResult` must not be dereferenced, and
+ // `genericOp` must not be erased, unless fusion actually succeeded.
+ if (failed(fusionResult))
+ return rewriter.notifyMatchFailure(genericOp, "fusion failed");
+ Operation *producer = opOperand.get().getDefiningOp();
+ for (auto [origVal, replacement] : fusionResult->replacements) {
+ rewriter.replaceUseIf(origVal, replacement, [&](OpOperand &use) {
+ // Only replace consumer uses.
+ return use.get().getDefiningOp() != producer;
+ });
}
+ rewriter.eraseOp(genericOp);
+ return success();
}
return failure();
}
diff --git a/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir b/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
new file mode 100644
index 0000000000000..7871ae08fd54a
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt -test-linalg-elementwise-fusion-patterns=fuse-multiuse-producer -split-input-file %s | FileCheck %s
+
+// The producer's results are used by the consumer and also returned directly.
+// The checks expect a single three-result `linalg.generic` whose results
+// supply every returned value.
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @multi_use_producer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+ %arg2 : tensor<?x?xf32>, %arg3 : tensor<?x?xf32>, %arg4 : tensor<?x?xf32>)
+ -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+ %0:2 = linalg.generic {
+ indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : tensor<?x?xf32>)
+ outs(%arg1, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>) {
+ ^bb0(%b0: f32, %b1 : f32, %b2 : f32):
+ %1 = arith.addf %b0, %b1 : f32
+ linalg.yield %1, %1 : f32, f32
+ } -> (tensor<?x?xf32>, tensor<?x?xf32>)
+ %2 = linalg.generic {
+ indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0#1, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg4 : tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+ %3 = arith.mulf %b0, %b1 : f32
+ linalg.yield %3 : f32
+ } -> tensor<?x?xf32>
+ return %0#0, %0#1, %2 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK: func @multi_use_producer(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+// CHECK: %[[RESULT:.+]]:3 = linalg.generic
+// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1, %[[RESULT]]#2
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index e2f61a9611b0d..21c9eaee5213b 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -51,6 +51,38 @@ static bool setFusedOpOperandLimit(OpOperand *fusedOperand) {
}
namespace {
+
+/// Pattern to test fusion of producer with consumer, even if producer has
+/// multiple uses. Fuses the matched op with the producer of its first fusable
+/// operand, then redirects every use of the original results that is not an
+/// operand of the matched op to the corresponding result of the fused op.
+struct TestMultiUseProducerFusion : public OpRewritePattern<linalg::GenericOp> {
+ using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+ // Pick the first operand for which elementwise fusion is legal.
+ OpOperand *fusableOperand = nullptr;
+ for (OpOperand &operand : genericOp->getOpOperands()) {
+ if (linalg::areElementwiseOpsFusable(&operand)) {
+ fusableOperand = &operand;
+ break;
+ }
+ }
+ if (!fusableOperand)
+ return rewriter.notifyMatchFailure(genericOp, "no fusable operand found");
+ FailureOr<linalg::ElementwiseOpFusionResult> fusionResult =
+ linalg::fuseElementwiseOps(rewriter, fusableOperand);
+ // Must return here: a failed result cannot be dereferenced, and the
+ // matched op must stay in place when nothing was fused.
+ if (failed(fusionResult))
+ return rewriter.notifyMatchFailure(genericOp, "fusion failed");
+ for (auto [origValue, replacement] : fusionResult->replacements) {
+ rewriter.replaceUseIf(origValue, replacement, [&](OpOperand &use) {
+ // Skip operands of the matched op itself; it is erased below.
+ return use.getOwner() != genericOp.getOperation();
+ });
+ }
+ rewriter.eraseOp(genericOp);
+ return success();
+ }
+};
+
struct TestLinalgElementwiseFusion
: public PassWrapper<TestLinalgElementwiseFusion,
OperationPass<func::FuncOp>> {
@@ -105,6 +137,12 @@ struct TestLinalgElementwiseFusion
"fusion patterns that "
"collapse the iteration space of the consumer"),
llvm::cl::init(false)};
+
+ Option<bool> fuseMultiUseProducer{
+ *this, "fuse-multiuse-producer",
+ llvm::cl::desc("Test fusion of producer ops with multiple uses"),
+ llvm::cl::init(false)};
+
ListOption<int64_t> collapseDimensions{
*this, "collapse-dimensions-control",
llvm::cl::desc("Test controlling dimension collapse pattern")};
@@ -117,8 +155,9 @@ struct TestLinalgElementwiseFusion
RewritePatternSet fusionPatterns(context);
auto controlFn = [](OpOperand *operand) { return true; };
linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn);
- (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
- std::move(fusionPatterns));
+ if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+ std::move(fusionPatterns))))
+ return signalPassFailure();
return;
}
@@ -127,8 +166,9 @@ struct TestLinalgElementwiseFusion
linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,
setFusedOpOperandLimit<4>);
- (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
- std::move(fusionPatterns));
+ if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+ std::move(fusionPatterns))))
+ return signalPassFailure();
return;
}
@@ -172,8 +212,9 @@ struct TestLinalgElementwiseFusion
linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
controlReshapeFusionFn);
- (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
- std::move(fusionPatterns));
+ if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+ std::move(fusionPatterns))))
+ return signalPassFailure();
return;
}
@@ -181,7 +222,10 @@ struct TestLinalgElementwiseFusion
RewritePatternSet patterns(context);
linalg::populateFoldReshapeOpsByCollapsingPatterns(
patterns, [](OpOperand * /*fusedOperand */) { return true; });
- (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+ std::move(patterns))))
+ return signalPassFailure();
+ return;
}
if (fuseWithReshapeByCollapsingWithControlFn) {
@@ -195,7 +239,19 @@ struct TestLinalgElementwiseFusion
return true;
};
linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
- (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+ std::move(patterns))))
+ return signalPassFailure();
+ return;
+ }
+
+ if (fuseMultiUseProducer) {
+ RewritePatternSet patterns(context);
+ patterns.insert<TestMultiUseProducerFusion>(context);
+ if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+ std::move(patterns))))
+ return signalPassFailure();
+ return;
}
if (!collapseDimensions.empty()) {
@@ -209,7 +265,10 @@ struct TestLinalgElementwiseFusion
};
RewritePatternSet patterns(context);
linalg::populateCollapseDimensions(patterns, collapseFn);
- (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+ std::move(patterns))))
+ return signalPassFailure();
+ return;
}
}
};
More information about the Mlir-commits
mailing list