[Mlir-commits] [mlir] [mlir][linalg] Add producer and consumer info to ControlPropagationFn. (PR #96697)

Han-Chung Wang llvmlistbot at llvm.org
Tue Jun 25 13:56:31 PDT 2024


https://github.com/hanhanW created https://github.com/llvm/llvm-project/pull/96697

It's not easy to determine whether we want to propagate pack/unpack ops because we don't know the (producer, consumer) information. Exposing both operations to the control function makes it much easier for downstream projects to make that decision.

>From a6f4f768965237739ffe692ef958cd27777c3bb7 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Tue, 25 Jun 2024 13:52:17 -0700
Subject: [PATCH] [mlir][linalg] Add producer and consumer info to
 ControlPropagationFn.

It's not easy to determine whether we want to propagate pack/unpack ops
because we don't know the (producer, consumer) information. Exposing both
operations to the control function makes it much easier for downstream
projects to make that decision.
---
 .../Dialect/Linalg/Transforms/Transforms.h    |  3 ++-
 .../Transforms/DataLayoutPropagation.cpp      | 23 +++++++++++--------
 .../Linalg/TestDataLayoutPropagation.cpp      |  3 ++-
 3 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index b0871a5dff5da..e14e1b988ac28 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1630,7 +1630,8 @@ void populateElementwiseOpsFusionPatterns(
 
 /// Function type which is used to control propagation of tensor.pack/unpack
 /// ops.
-using ControlPropagationFn = std::function<bool(Operation *op)>;
+using ControlPropagationFn =
+    std::function<bool(Operation *producer, Operation *consumer)>;
 
 /// Patterns to bubble up or down data layout ops across other operations.
 void populateDataLayoutPropagationPatterns(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index e51ae2264a36a..f87529211dce6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -378,7 +378,7 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
     return failure();
 
   // User controlled propagation function.
-  if (!controlFn(genericOp))
+  if (!controlFn(genericOp, packOp))
     return failure();
 
   // TODO: Enable propagation in the presence of linalg.index and
@@ -488,7 +488,7 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
       return failure();
 
     // User controlled propagation function.
-    if (!controlFn(padOp))
+    if (!controlFn(padOp, packOp))
       return failure();
 
     if (!padOp.getResult().hasOneUse())
@@ -843,7 +843,7 @@ class BubbleUpPackOpThroughReshapeOp final
     }
 
     // User controlled propagation function.
-    if (!controlFn(srcOp))
+    if (!controlFn(srcOp, packOp))
       return failure();
 
     return TypeSwitch<Operation *, LogicalResult>(srcOp)
@@ -970,7 +970,7 @@ class PushDownUnPackOpThroughReshapeOp final
 
     Operation *consumerOp = *result.user_begin();
     // User controlled propagation function.
-    if (!controlFn(consumerOp))
+    if (!controlFn(unPackOp, consumerOp))
       return failure();
 
     return TypeSwitch<Operation *, LogicalResult>(consumerOp)
@@ -1037,7 +1037,8 @@ static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
 ///                       inner_dims_pos = [3] inner_tiles = [32] into %0
 ///
 static FailureOr<std::tuple<GenericOp, Value>>
-pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
+pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
+                                 ControlPropagationFn controlFn) {
   if (genericOp.getNumResults() != 1)
     return failure();
 
@@ -1054,6 +1055,10 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
   tensor::UnPackOp producerUnPackOp =
       unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
   assert(producerUnPackOp && "expect a valid UnPackOp");
+
+  if (!controlFn(producerUnPackOp, genericOp))
+    return failure();
+
   auto packInfo =
       getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
   if (failed(packInfo))
@@ -1121,10 +1126,8 @@ struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    if (!controlFn(genericOp))
-      return failure();
-
-    auto genericAndRepl = pushDownUnPackOpThroughGenericOp(rewriter, genericOp);
+    auto genericAndRepl =
+        pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
     if (failed(genericAndRepl))
       return failure();
     rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
@@ -1149,7 +1152,7 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
     if (!unpackOp)
       return failure();
 
-    if (!controlFn(padOp))
+    if (!controlFn(unpackOp, padOp))
       return failure();
 
     Location loc = padOp.getLoc();
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index b5998d9c851e4..55d0715bbb9bb 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -33,7 +33,8 @@ struct TestDataLayoutPropagationPass
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
     linalg::populateDataLayoutPropagationPatterns(
-        patterns, [](Operation *op) { return true; });
+        patterns,
+        [](Operation *producer, Operation *consumer) { return true; });
     if (failed(
             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
       return signalPassFailure();



More information about the Mlir-commits mailing list