[Mlir-commits] [mlir] [mlir][linalg] Add producer and consumer info to ControlPropagationFn. (PR #96697)

Han-Chung Wang llvmlistbot at llvm.org
Mon Jul 1 13:56:37 PDT 2024


https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/96697

>From a6f4f768965237739ffe692ef958cd27777c3bb7 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Tue, 25 Jun 2024 13:52:17 -0700
Subject: [PATCH 1/2] [mlir][linalg] Add producer and consumer info to
 ControlPropagationFn.

It's not easy to determine whether we want to propagate pack/unpack ops
because we don't know the (producer, consumer) information. Exposing both
operations to the control function helps a lot to make the decision in
downstream projects.
---
 .../Dialect/Linalg/Transforms/Transforms.h    |  3 ++-
 .../Transforms/DataLayoutPropagation.cpp      | 23 +++++++++++--------
 .../Linalg/TestDataLayoutPropagation.cpp      |  3 ++-
 3 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index b0871a5dff5da..e14e1b988ac28 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1630,7 +1630,8 @@ void populateElementwiseOpsFusionPatterns(
 
 /// Function type which is used to control propagation of tensor.pack/unpack
 /// ops.
-using ControlPropagationFn = std::function<bool(Operation *op)>;
+using ControlPropagationFn =
+    std::function<bool(Operation *producer, Operation *consumer)>;
 
 /// Patterns to bubble up or down data layout ops across other operations.
 void populateDataLayoutPropagationPatterns(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index e51ae2264a36a..f87529211dce6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -378,7 +378,7 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
     return failure();
 
   // User controlled propagation function.
-  if (!controlFn(genericOp))
+  if (!controlFn(genericOp, packOp))
     return failure();
 
   // TODO: Enable propagation in the presence of linalg.index and
@@ -488,7 +488,7 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
       return failure();
 
     // User controlled propagation function.
-    if (!controlFn(padOp))
+    if (!controlFn(padOp, packOp))
       return failure();
 
     if (!padOp.getResult().hasOneUse())
@@ -843,7 +843,7 @@ class BubbleUpPackOpThroughReshapeOp final
     }
 
     // User controlled propagation function.
-    if (!controlFn(srcOp))
+    if (!controlFn(srcOp, packOp))
       return failure();
 
     return TypeSwitch<Operation *, LogicalResult>(srcOp)
@@ -970,7 +970,7 @@ class PushDownUnPackOpThroughReshapeOp final
 
     Operation *consumerOp = *result.user_begin();
     // User controlled propagation function.
-    if (!controlFn(consumerOp))
+    if (!controlFn(unPackOp, consumerOp))
       return failure();
 
     return TypeSwitch<Operation *, LogicalResult>(consumerOp)
@@ -1037,7 +1037,8 @@ static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
 ///                       inner_dims_pos = [3] inner_tiles = [32] into %0
 ///
 static FailureOr<std::tuple<GenericOp, Value>>
-pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
+pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
+                                 ControlPropagationFn controlFn) {
   if (genericOp.getNumResults() != 1)
     return failure();
 
@@ -1054,6 +1055,10 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
   tensor::UnPackOp producerUnPackOp =
       unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
   assert(producerUnPackOp && "expect a valid UnPackOp");
+
+  if (!controlFn(producerUnPackOp, genericOp))
+    return failure();
+
   auto packInfo =
       getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
   if (failed(packInfo))
@@ -1121,10 +1126,8 @@ struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    if (!controlFn(genericOp))
-      return failure();
-
-    auto genericAndRepl = pushDownUnPackOpThroughGenericOp(rewriter, genericOp);
+    auto genericAndRepl =
+        pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
     if (failed(genericAndRepl))
       return failure();
     rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
@@ -1149,7 +1152,7 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
     if (!unpackOp)
       return failure();
 
-    if (!controlFn(padOp))
+    if (!controlFn(unpackOp, padOp))
       return failure();
 
     Location loc = padOp.getLoc();
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index b5998d9c851e4..55d0715bbb9bb 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -33,7 +33,8 @@ struct TestDataLayoutPropagationPass
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
     linalg::populateDataLayoutPropagationPatterns(
-        patterns, [](Operation *op) { return true; });
+        patterns,
+        [](Operation *producer, Operation *consumer) { return true; });
     if (failed(
             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
       return signalPassFailure();

>From cf92e50d2c802b66f0d27a682c56f9bc5d8644ef Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 1 Jul 2024 13:56:09 -0700
Subject: [PATCH 2/2] Switch to use `OpOperand*`

---
 .../Dialect/Linalg/Transforms/Transforms.h    |  3 +-
 .../Transforms/DataLayoutPropagation.cpp      | 28 +++++++++----------
 .../Linalg/TestDataLayoutPropagation.cpp      |  3 +-
 3 files changed, 16 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index e14e1b988ac28..e080afaa3d66a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1630,8 +1630,7 @@ void populateElementwiseOpsFusionPatterns(
 
 /// Function type which is used to control propagation of tensor.pack/unpack
 /// ops.
-using ControlPropagationFn =
-    std::function<bool(Operation *producer, Operation *consumer)>;
+using ControlPropagationFn = std::function<bool(OpOperand *opOperand)>;
 
 /// Patterns to bubble up or down data layout ops across other operations.
 void populateDataLayoutPropagationPatterns(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index f87529211dce6..5ae85e6ae2e6d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -378,7 +378,7 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
     return failure();
 
   // User controlled propagation function.
-  if (!controlFn(genericOp, packOp))
+  if (!controlFn(&packOp.getSourceMutable()))
     return failure();
 
   // TODO: Enable propagation in the presence of linalg.index and
@@ -488,7 +488,7 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
       return failure();
 
     // User controlled propagation function.
-    if (!controlFn(padOp, packOp))
+    if (!controlFn(&packOp.getSourceMutable()))
       return failure();
 
     if (!padOp.getResult().hasOneUse())
@@ -843,7 +843,7 @@ class BubbleUpPackOpThroughReshapeOp final
     }
 
     // User controlled propagation function.
-    if (!controlFn(srcOp, packOp))
+    if (!controlFn(&packOp.getSourceMutable()))
       return failure();
 
     return TypeSwitch<Operation *, LogicalResult>(srcOp)
@@ -879,10 +879,13 @@ class BubbleUpPackOpThroughReshapeOp final
 /// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
 ///     inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
 ///     : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
-static LogicalResult
-pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
-                                   tensor::ExpandShapeOp expandOp,
-                                   PatternRewriter &rewriter) {
+static LogicalResult pushDownUnPackOpThroughExpandShape(
+    tensor::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
+    PatternRewriter &rewriter, ControlPropagationFn controlFn) {
+  // User controlled propagation function.
+  if (!controlFn(expandOp.getSrcMutable()))
+    return failure();
+
   SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
   ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
   ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
@@ -969,13 +972,10 @@ class PushDownUnPackOpThroughReshapeOp final
     }
 
     Operation *consumerOp = *result.user_begin();
-    // User controlled propagation function.
-    if (!controlFn(unPackOp, consumerOp))
-      return failure();
-
     return TypeSwitch<Operation *, LogicalResult>(consumerOp)
         .Case([&](tensor::ExpandShapeOp op) {
-          return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter);
+          return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
+                                                    controlFn);
         })
         .Default([](Operation *) { return failure(); });
   }
@@ -1056,7 +1056,7 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
       unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
   assert(producerUnPackOp && "expect a valid UnPackOp");
 
-  if (!controlFn(producerUnPackOp, genericOp))
+  if (!controlFn(unPackedOperand))
     return failure();
 
   auto packInfo =
@@ -1152,7 +1152,7 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
     if (!unpackOp)
       return failure();
 
-    if (!controlFn(unpackOp, padOp))
+    if (!controlFn(&padOp.getSourceMutable()))
       return failure();
 
     Location loc = padOp.getLoc();
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index 55d0715bbb9bb..4cf2460150d14 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -33,8 +33,7 @@ struct TestDataLayoutPropagationPass
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
     linalg::populateDataLayoutPropagationPatterns(
-        patterns,
-        [](Operation *producer, Operation *consumer) { return true; });
+        patterns, [](OpOperand *opOperand) { return true; });
     if (failed(
             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
       return signalPassFailure();



More information about the Mlir-commits mailing list