[Mlir-commits] [mlir] 04fc471 - [mlir][linalg] Switch to use OpOperand* in ControlPropagationFn. (#96697)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jul 8 09:53:13 PDT 2024
Author: Han-Chung Wang
Date: 2024-07-08T09:53:09-07:00
New Revision: 04fc471f485a9beadd8ccc63f6af29765ec6f45b
URL: https://github.com/llvm/llvm-project/commit/04fc471f485a9beadd8ccc63f6af29765ec6f45b
DIFF: https://github.com/llvm/llvm-project/commit/04fc471f485a9beadd8ccc63f6af29765ec6f45b.diff
LOG: [mlir][linalg] Switch to use OpOperand* in ControlPropagationFn. (#96697)
It is not easy to decide whether pack/unpack ops should be propagated,
because the control function is not given the (producer, consumer)
information. This revision switches the control function's argument to
`OpOperand*`, so the control function can recover the (producer, consumer) pair, e.g.,
```
Operation *producer = opOperand->get().getDefiningOp();
Operation *consumer = opOperand->getOwner();
```
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2a58d02d7b704..693fca4f63502 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1652,7 +1652,7 @@ void populateElementwiseOpsFusionPatterns(
/// Function type which is used to control propagation of tensor.pack/unpack
/// ops.
-using ControlPropagationFn = std::function<bool(Operation *op)>;
+using ControlPropagationFn = std::function<bool(OpOperand *opOperand)>;
/// Patterns to bubble up or down data layout ops across other operations.
void populateDataLayoutPropagationPatterns(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 6984bc2dff498..0d7ab7232e1e6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -378,7 +378,7 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
return failure();
// User controlled propagation function.
- if (!controlFn(genericOp))
+ if (!controlFn(&packOp.getSourceMutable()))
return failure();
// TODO: Enable propagation in the presence of linalg.index and
@@ -488,7 +488,7 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
return failure();
// User controlled propagation function.
- if (!controlFn(padOp))
+ if (!controlFn(&packOp.getSourceMutable()))
return failure();
if (!padOp.getResult().hasOneUse())
@@ -844,7 +844,7 @@ class BubbleUpPackOpThroughReshapeOp final
}
// User controlled propagation function.
- if (!controlFn(srcOp))
+ if (!controlFn(&packOp.getSourceMutable()))
return failure();
return TypeSwitch<Operation *, LogicalResult>(srcOp)
@@ -880,10 +880,13 @@ class BubbleUpPackOpThroughReshapeOp final
/// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
/// inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
/// : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
-static LogicalResult
-pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
- tensor::ExpandShapeOp expandOp,
- PatternRewriter &rewriter) {
+static LogicalResult pushDownUnPackOpThroughExpandShape(
+ tensor::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
+ PatternRewriter &rewriter, ControlPropagationFn controlFn) {
+ // User controlled propagation function.
+ if (!controlFn(&expandOp.getSrcMutable()))
+ return failure();
+
SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
@@ -970,13 +973,10 @@ class PushDownUnPackOpThroughReshapeOp final
}
Operation *consumerOp = *result.user_begin();
- // User controlled propagation function.
- if (!controlFn(consumerOp))
- return failure();
-
return TypeSwitch<Operation *, LogicalResult>(consumerOp)
.Case([&](tensor::ExpandShapeOp op) {
- return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter);
+ return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
+ controlFn);
})
.Default([](Operation *) { return failure(); });
}
@@ -1038,7 +1038,8 @@ static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
/// inner_dims_pos = [3] inner_tiles = [32] into %0
///
static FailureOr<std::tuple<GenericOp, Value>>
-pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
+pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
+ ControlPropagationFn controlFn) {
if (genericOp.getNumResults() != 1)
return failure();
@@ -1055,6 +1056,10 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
tensor::UnPackOp producerUnPackOp =
unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
assert(producerUnPackOp && "expect a valid UnPackOp");
+
+ if (!controlFn(unPackedOperand))
+ return failure();
+
auto packInfo =
getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
if (failed(packInfo))
@@ -1122,10 +1127,8 @@ struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
- if (!controlFn(genericOp))
- return failure();
-
- auto genericAndRepl = pushDownUnPackOpThroughGenericOp(rewriter, genericOp);
+ auto genericAndRepl =
+ pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
if (failed(genericAndRepl))
return failure();
rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
@@ -1150,7 +1153,7 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
if (!unpackOp)
return failure();
- if (!controlFn(padOp))
+ if (!controlFn(&padOp.getSourceMutable()))
return failure();
Location loc = padOp.getLoc();
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index b5998d9c851e4..4cf2460150d14 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -33,7 +33,7 @@ struct TestDataLayoutPropagationPass
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
linalg::populateDataLayoutPropagationPatterns(
- patterns, [](Operation *op) { return true; });
+ patterns, [](OpOperand *opOperand) { return true; });
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
More information about the Mlir-commits
mailing list