[Mlir-commits] [mlir] [mlir][linalg] Add producer and consumer info to ControlPropagationFn. (PR #96697)
Han-Chung Wang
llvmlistbot at llvm.org
Mon Jul 1 13:56:37 PDT 2024
https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/96697
>From a6f4f768965237739ffe692ef958cd27777c3bb7 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Tue, 25 Jun 2024 13:52:17 -0700
Subject: [PATCH 1/2] [mlir][linalg] Add producer and consumer info to
ControlPropagationFn.
It's not easy to determine whether we want to propagate pack/unpack ops
because we don't know the (producer, consumer) information. Exposing both
operations to the control function helps a lot in making the decision in
downstream projects.
---
.../Dialect/Linalg/Transforms/Transforms.h | 3 ++-
.../Transforms/DataLayoutPropagation.cpp | 23 +++++++++++--------
.../Linalg/TestDataLayoutPropagation.cpp | 3 ++-
3 files changed, 17 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index b0871a5dff5da..e14e1b988ac28 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1630,7 +1630,8 @@ void populateElementwiseOpsFusionPatterns(
/// Function type which is used to control propagation of tensor.pack/unpack
/// ops.
-using ControlPropagationFn = std::function<bool(Operation *op)>;
+using ControlPropagationFn =
+ std::function<bool(Operation *producer, Operation *consumer)>;
/// Patterns to bubble up or down data layout ops across other operations.
void populateDataLayoutPropagationPatterns(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index e51ae2264a36a..f87529211dce6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -378,7 +378,7 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
return failure();
// User controlled propagation function.
- if (!controlFn(genericOp))
+ if (!controlFn(genericOp, packOp))
return failure();
// TODO: Enable propagation in the presence of linalg.index and
@@ -488,7 +488,7 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
return failure();
// User controlled propagation function.
- if (!controlFn(padOp))
+ if (!controlFn(padOp, packOp))
return failure();
if (!padOp.getResult().hasOneUse())
@@ -843,7 +843,7 @@ class BubbleUpPackOpThroughReshapeOp final
}
// User controlled propagation function.
- if (!controlFn(srcOp))
+ if (!controlFn(srcOp, packOp))
return failure();
return TypeSwitch<Operation *, LogicalResult>(srcOp)
@@ -970,7 +970,7 @@ class PushDownUnPackOpThroughReshapeOp final
Operation *consumerOp = *result.user_begin();
// User controlled propagation function.
- if (!controlFn(consumerOp))
+ if (!controlFn(unPackOp, consumerOp))
return failure();
return TypeSwitch<Operation *, LogicalResult>(consumerOp)
@@ -1037,7 +1037,8 @@ static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
/// inner_dims_pos = [3] inner_tiles = [32] into %0
///
static FailureOr<std::tuple<GenericOp, Value>>
-pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
+pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
+ ControlPropagationFn controlFn) {
if (genericOp.getNumResults() != 1)
return failure();
@@ -1054,6 +1055,10 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
tensor::UnPackOp producerUnPackOp =
unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
assert(producerUnPackOp && "expect a valid UnPackOp");
+
+ if (!controlFn(producerUnPackOp, genericOp))
+ return failure();
+
auto packInfo =
getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
if (failed(packInfo))
@@ -1121,10 +1126,8 @@ struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
- if (!controlFn(genericOp))
- return failure();
-
- auto genericAndRepl = pushDownUnPackOpThroughGenericOp(rewriter, genericOp);
+ auto genericAndRepl =
+ pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
if (failed(genericAndRepl))
return failure();
rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
@@ -1149,7 +1152,7 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
if (!unpackOp)
return failure();
- if (!controlFn(padOp))
+ if (!controlFn(unpackOp, padOp))
return failure();
Location loc = padOp.getLoc();
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index b5998d9c851e4..55d0715bbb9bb 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -33,7 +33,8 @@ struct TestDataLayoutPropagationPass
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
linalg::populateDataLayoutPropagationPatterns(
- patterns, [](Operation *op) { return true; });
+ patterns,
+ [](Operation *producer, Operation *consumer) { return true; });
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
>From cf92e50d2c802b66f0d27a682c56f9bc5d8644ef Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 1 Jul 2024 13:56:09 -0700
Subject: [PATCH 2/2] Switch to use `OpOperand*`
---
.../Dialect/Linalg/Transforms/Transforms.h | 3 +-
.../Transforms/DataLayoutPropagation.cpp | 28 +++++++++----------
.../Linalg/TestDataLayoutPropagation.cpp | 3 +-
3 files changed, 16 insertions(+), 18 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index e14e1b988ac28..e080afaa3d66a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1630,8 +1630,7 @@ void populateElementwiseOpsFusionPatterns(
/// Function type which is used to control propagation of tensor.pack/unpack
/// ops.
-using ControlPropagationFn =
- std::function<bool(Operation *producer, Operation *consumer)>;
+using ControlPropagationFn = std::function<bool(OpOperand *opOperand)>;
/// Patterns to bubble up or down data layout ops across other operations.
void populateDataLayoutPropagationPatterns(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index f87529211dce6..5ae85e6ae2e6d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -378,7 +378,7 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
return failure();
// User controlled propagation function.
- if (!controlFn(genericOp, packOp))
+ if (!controlFn(&packOp.getSourceMutable()))
return failure();
// TODO: Enable propagation in the presence of linalg.index and
@@ -488,7 +488,7 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
return failure();
// User controlled propagation function.
- if (!controlFn(padOp, packOp))
+ if (!controlFn(&packOp.getSourceMutable()))
return failure();
if (!padOp.getResult().hasOneUse())
@@ -843,7 +843,7 @@ class BubbleUpPackOpThroughReshapeOp final
}
// User controlled propagation function.
- if (!controlFn(srcOp, packOp))
+ if (!controlFn(&packOp.getSourceMutable()))
return failure();
return TypeSwitch<Operation *, LogicalResult>(srcOp)
@@ -879,10 +879,13 @@ class BubbleUpPackOpThroughReshapeOp final
/// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
/// inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
/// : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
-static LogicalResult
-pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
- tensor::ExpandShapeOp expandOp,
- PatternRewriter &rewriter) {
+static LogicalResult pushDownUnPackOpThroughExpandShape(
+ tensor::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
+ PatternRewriter &rewriter, ControlPropagationFn controlFn) {
+ // User controlled propagation function.
+  if (!controlFn(&expandOp.getSrcMutable()))
+ return failure();
+
SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
@@ -969,13 +972,10 @@ class PushDownUnPackOpThroughReshapeOp final
}
Operation *consumerOp = *result.user_begin();
- // User controlled propagation function.
- if (!controlFn(unPackOp, consumerOp))
- return failure();
-
return TypeSwitch<Operation *, LogicalResult>(consumerOp)
.Case([&](tensor::ExpandShapeOp op) {
- return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter);
+ return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
+ controlFn);
})
.Default([](Operation *) { return failure(); });
}
@@ -1056,7 +1056,7 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
assert(producerUnPackOp && "expect a valid UnPackOp");
- if (!controlFn(producerUnPackOp, genericOp))
+ if (!controlFn(unPackedOperand))
return failure();
auto packInfo =
@@ -1152,7 +1152,7 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
if (!unpackOp)
return failure();
- if (!controlFn(unpackOp, padOp))
+ if (!controlFn(&padOp.getSourceMutable()))
return failure();
Location loc = padOp.getLoc();
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index 55d0715bbb9bb..4cf2460150d14 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -33,8 +33,7 @@ struct TestDataLayoutPropagationPass
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
linalg::populateDataLayoutPropagationPatterns(
- patterns,
- [](Operation *producer, Operation *consumer) { return true; });
+ patterns, [](OpOperand *opOperand) { return true; });
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
More information about the Mlir-commits
mailing list