[Mlir-commits] [mlir] 04fc471 - [mlir][linalg] Switch to use OpOperand* in ControlPropagationFn. (#96697)

llvmlistbot at llvm.org
Mon Jul 8 09:53:13 PDT 2024


Author: Han-Chung Wang
Date: 2024-07-08T09:53:09-07:00
New Revision: 04fc471f485a9beadd8ccc63f6af29765ec6f45b

URL: https://github.com/llvm/llvm-project/commit/04fc471f485a9beadd8ccc63f6af29765ec6f45b
DIFF: https://github.com/llvm/llvm-project/commit/04fc471f485a9beadd8ccc63f6af29765ec6f45b.diff

LOG: [mlir][linalg] Switch to use OpOperand* in ControlPropagationFn. (#96697)

It's not easy to determine whether we want to propagate pack/unpack ops
when the control function only sees an `Operation *`, because the
(producer, consumer) information is unknown. This revision switches the
callback to take an `OpOperand *`, so the control function can recover
the (producer, consumer) pair. E.g.,

```
Operation *producer = opOperand->get().getDefiningOp();
Operation *consumer = opOperand->getOwner();
```
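For instance, a caller could use the new signature to propagate only across
specific producer/consumer pairs. A minimal sketch (the predicate below is a
hypothetical example, not part of this commit):

```
// Hypothetical control function: only propagate when a tensor.pack
// producer feeds a linalg.generic consumer. Uses the new OpOperand*
// signature to inspect both sides of the edge.
linalg::ControlPropagationFn controlFn = [](OpOperand *opOperand) {
  Operation *producer = opOperand->get().getDefiningOp();
  Operation *consumer = opOperand->getOwner();
  return isa_and_nonnull<tensor::PackOp>(producer) &&
         isa<linalg::GenericOp>(consumer);
};
```

Passing such a function to `populateDataLayoutPropagationPatterns` restricts
propagation to the pairs it accepts; the updated test pass below simply
returns `true` for every operand.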

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
    mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2a58d02d7b704..693fca4f63502 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1652,7 +1652,7 @@ void populateElementwiseOpsFusionPatterns(
 
 /// Function type which is used to control propagation of tensor.pack/unpack
 /// ops.
-using ControlPropagationFn = std::function<bool(Operation *op)>;
+using ControlPropagationFn = std::function<bool(OpOperand *opOperand)>;
 
 /// Patterns to bubble up or down data layout ops across other operations.
 void populateDataLayoutPropagationPatterns(

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 6984bc2dff498..0d7ab7232e1e6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -378,7 +378,7 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
     return failure();
 
   // User controlled propagation function.
-  if (!controlFn(genericOp))
+  if (!controlFn(&packOp.getSourceMutable()))
     return failure();
 
   // TODO: Enable propagation in the presence of linalg.index and
@@ -488,7 +488,7 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
       return failure();
 
     // User controlled propagation function.
-    if (!controlFn(padOp))
+    if (!controlFn(&packOp.getSourceMutable()))
       return failure();
 
     if (!padOp.getResult().hasOneUse())
@@ -844,7 +844,7 @@ class BubbleUpPackOpThroughReshapeOp final
     }
 
     // User controlled propagation function.
-    if (!controlFn(srcOp))
+    if (!controlFn(&packOp.getSourceMutable()))
       return failure();
 
     return TypeSwitch<Operation *, LogicalResult>(srcOp)
@@ -880,10 +880,13 @@ class BubbleUpPackOpThroughReshapeOp final
 /// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
 ///     inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
 ///     : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
-static LogicalResult
-pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
-                                   tensor::ExpandShapeOp expandOp,
-                                   PatternRewriter &rewriter) {
+static LogicalResult pushDownUnPackOpThroughExpandShape(
+    tensor::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
+    PatternRewriter &rewriter, ControlPropagationFn controlFn) {
+  // User controlled propagation function.
+  if (!controlFn(&expandOp.getSrcMutable()))
+    return failure();
+
   SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
   ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
   ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
@@ -970,13 +973,10 @@ class PushDownUnPackOpThroughReshapeOp final
     }
 
     Operation *consumerOp = *result.user_begin();
-    // User controlled propagation function.
-    if (!controlFn(consumerOp))
-      return failure();
-
     return TypeSwitch<Operation *, LogicalResult>(consumerOp)
         .Case([&](tensor::ExpandShapeOp op) {
-          return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter);
+          return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
+                                                    controlFn);
         })
         .Default([](Operation *) { return failure(); });
   }
@@ -1038,7 +1038,8 @@ static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
 ///                       inner_dims_pos = [3] inner_tiles = [32] into %0
 ///
 static FailureOr<std::tuple<GenericOp, Value>>
-pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
+pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
+                                 ControlPropagationFn controlFn) {
   if (genericOp.getNumResults() != 1)
     return failure();
 
@@ -1055,6 +1056,10 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
   tensor::UnPackOp producerUnPackOp =
       unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
   assert(producerUnPackOp && "expect a valid UnPackOp");
+
+  if (!controlFn(unPackedOperand))
+    return failure();
+
   auto packInfo =
       getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
   if (failed(packInfo))
@@ -1122,10 +1127,8 @@ struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    if (!controlFn(genericOp))
-      return failure();
-
-    auto genericAndRepl = pushDownUnPackOpThroughGenericOp(rewriter, genericOp);
+    auto genericAndRepl =
+        pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
     if (failed(genericAndRepl))
       return failure();
     rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
@@ -1150,7 +1153,7 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
     if (!unpackOp)
       return failure();
 
-    if (!controlFn(padOp))
+    if (!controlFn(&padOp.getSourceMutable()))
       return failure();
 
     Location loc = padOp.getLoc();

diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index b5998d9c851e4..4cf2460150d14 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -33,7 +33,7 @@ struct TestDataLayoutPropagationPass
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
     linalg::populateDataLayoutPropagationPatterns(
-        patterns, [](Operation *op) { return true; });
+        patterns, [](OpOperand *opOperand) { return true; });
     if (failed(
             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
       return signalPassFailure();
