[Mlir-commits] [mlir] [mlir][linalg] Add producer and consumer info to ControlPropagationFn. (PR #96697)
Han-Chung Wang
llvmlistbot at llvm.org
Tue Jun 25 13:56:31 PDT 2024
https://github.com/hanhanW created https://github.com/llvm/llvm-project/pull/96697
It's not easy to determine whether we want to propagate pack/unpack ops because we don't know the (producer, consumer) information. Exposing both operations to the control function helps a lot in making that decision in downstream projects.
>From a6f4f768965237739ffe692ef958cd27777c3bb7 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Tue, 25 Jun 2024 13:52:17 -0700
Subject: [PATCH] [mlir][linalg] Add producer and consumer info to
ControlPropagationFn.
It's not easy to determine whether we want to propagate pack/unpack ops
because we don't know the (producer, consumer) information. Exposing both
operations to the control function helps a lot in making that decision in
downstream projects.
---
.../Dialect/Linalg/Transforms/Transforms.h | 3 ++-
.../Transforms/DataLayoutPropagation.cpp | 23 +++++++++++--------
.../Linalg/TestDataLayoutPropagation.cpp | 3 ++-
3 files changed, 17 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index b0871a5dff5da..e14e1b988ac28 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1630,7 +1630,8 @@ void populateElementwiseOpsFusionPatterns(
/// Function type which is used to control propagation of tensor.pack/unpack
/// ops.
-using ControlPropagationFn = std::function<bool(Operation *op)>;
+using ControlPropagationFn =
+ std::function<bool(Operation *producer, Operation *consumer)>;
/// Patterns to bubble up or down data layout ops across other operations.
void populateDataLayoutPropagationPatterns(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index e51ae2264a36a..f87529211dce6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -378,7 +378,7 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
return failure();
// User controlled propagation function.
- if (!controlFn(genericOp))
+ if (!controlFn(genericOp, packOp))
return failure();
// TODO: Enable propagation in the presence of linalg.index and
@@ -488,7 +488,7 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
return failure();
// User controlled propagation function.
- if (!controlFn(padOp))
+ if (!controlFn(padOp, packOp))
return failure();
if (!padOp.getResult().hasOneUse())
@@ -843,7 +843,7 @@ class BubbleUpPackOpThroughReshapeOp final
}
// User controlled propagation function.
- if (!controlFn(srcOp))
+ if (!controlFn(srcOp, packOp))
return failure();
return TypeSwitch<Operation *, LogicalResult>(srcOp)
@@ -970,7 +970,7 @@ class PushDownUnPackOpThroughReshapeOp final
Operation *consumerOp = *result.user_begin();
// User controlled propagation function.
- if (!controlFn(consumerOp))
+ if (!controlFn(unPackOp, consumerOp))
return failure();
return TypeSwitch<Operation *, LogicalResult>(consumerOp)
@@ -1037,7 +1037,8 @@ static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
/// inner_dims_pos = [3] inner_tiles = [32] into %0
///
static FailureOr<std::tuple<GenericOp, Value>>
-pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
+pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
+ ControlPropagationFn controlFn) {
if (genericOp.getNumResults() != 1)
return failure();
@@ -1054,6 +1055,10 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
tensor::UnPackOp producerUnPackOp =
unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
assert(producerUnPackOp && "expect a valid UnPackOp");
+
+ if (!controlFn(producerUnPackOp, genericOp))
+ return failure();
+
auto packInfo =
getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
if (failed(packInfo))
@@ -1121,10 +1126,8 @@ struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
- if (!controlFn(genericOp))
- return failure();
-
- auto genericAndRepl = pushDownUnPackOpThroughGenericOp(rewriter, genericOp);
+ auto genericAndRepl =
+ pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
if (failed(genericAndRepl))
return failure();
rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
@@ -1149,7 +1152,7 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
if (!unpackOp)
return failure();
- if (!controlFn(padOp))
+ if (!controlFn(unpackOp, padOp))
return failure();
Location loc = padOp.getLoc();
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index b5998d9c851e4..55d0715bbb9bb 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -33,7 +33,8 @@ struct TestDataLayoutPropagationPass
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
linalg::populateDataLayoutPropagationPatterns(
- patterns, [](Operation *op) { return true; });
+ patterns,
+ [](Operation *producer, Operation *consumer) { return true; });
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
More information about the Mlir-commits
mailing list