[Mlir-commits] [mlir] 5310be5 - [mlir] make `fuse_into_containing_op` preserve the containing op handle

Alex Zinenko llvmlistbot at llvm.org
Fri May 26 09:01:48 PDT 2023


Author: Alex Zinenko
Date: 2023-05-26T16:01:40Z
New Revision: 5310be521db2aa8c05a1c1adb7e108fc2c7c9ddc

URL: https://github.com/llvm/llvm-project/commit/5310be521db2aa8c05a1c1adb7e108fc2c7c9ddc
DIFF: https://github.com/llvm/llvm-project/commit/5310be521db2aa8c05a1c1adb7e108fc2c7c9ddc.diff

LOG: [mlir] make `fuse_into_containing_op` preserve the containing op handle

This partially undoes the intent of https://reviews.llvm.org/D151418 by
cheating its way to keeping the "containing op" (a.k.a. loop) handle
read-only in fusion. Doing so is crucial for the composability of tiling
and fusion. Specifically, once the "containing op" handle started being
consumed, it became impossible to perform further tiling or fusion after
a fusion step, except on the last-fused op:

  %tiled1, %loop1 = tile %op
  %producer1, %loop2 = fuse %producer into %loop1
  // invalid, because %tiled1 is invalidated by consuming %loop1
  // that points to its parent
  tile %tiled1

or

  %tiled1, %loop1 = tile %op
  %tiled2, %loop2 = tile %tiled1
  %p2 = fuse %producer into %loop1
  // invalid, because %loop2 is invalidated by consuming %loop1
  // that points to its parent
  fuse %p2 into %loop2
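Both examples rely on the same rule: consuming a handle invalidates every handle whose payload op is the consumed op or is nested inside it. The minimal C++ sketch below models that rule under simplifying assumptions. Each handle maps to exactly one payload op, and ops are identified by name. `PayloadIR`, `HandleState`, `isAncestorOrSelf` and `consume` are hypothetical names, not the MLIR API.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>

// Hypothetical, heavily simplified model of payload IR nesting.
// This is NOT the MLIR API; names and types are illustrative only.
struct PayloadIR {
  // Maps a payload op name to the name of its parent op (absent = top level).
  std::map<std::string, std::string> parentOf;

  // True if `op` is `ancestor` itself or is (transitively) nested inside it.
  bool isAncestorOrSelf(const std::string &ancestor, std::string op) const {
    while (!op.empty()) {
      if (op == ancestor)
        return true;
      auto it = parentOf.find(op);
      op = (it == parentOf.end()) ? std::string() : it->second;
    }
    return false;
  }
};

// Hypothetical handle table: in this sketch each handle points to one op.
struct HandleState {
  std::map<std::string, std::string> handleToOp;
  std::set<std::string> invalidated;

  // Consuming `handle` invalidates it and every handle whose payload op is
  // the consumed op or is nested inside it.
  void consume(const PayloadIR &ir, const std::string &handle) {
    const std::string consumedOp = handleToOp.at(handle);
    for (const auto &[other, op] : handleToOp)
      if (ir.isAncestorOrSelf(consumedOp, op))
        invalidated.insert(other);
  }

  bool isValid(const std::string &handle) const {
    return invalidated.count(handle) == 0;
  }
};
```

In this model, making `fuse` only read the containing op handle means `consume` is never called on `%loop1`. As a result, `%tiled1` and `%loop2` remain valid, and only the producer handle is consumed.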

The approach here makes creative use of the state extension mechanism to
update the payload operation associated with the operand handle. Further
investigation is necessary to understand whether this is consistent with
the overall execution model of the transform dialect, but it is crucial to
restore composability ASAP.
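The trick in the diff below boils down to a standard C++ idiom. A derived class can call a protected base-class member and forward it through a public method. The sketch below shows that idiom in isolation, using hypothetical stand-ins for `TransformState`, `TransformState::Extension` and the helpers added by this commit: `State`, `Extension`, `UnsafeReplacementExtension` and `forciblyReplace`. It does not reproduce the real MLIR classes, and the string-based "ops" are an assumption made to keep it self-contained.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for the transform state: handle name -> payload ops.
class State {
public:
  std::map<std::string, std::vector<std::string>> handleToOps;
};

// Hypothetical stand-in for a state extension with a protected mutator.
class Extension {
public:
  explicit Extension(State &s) : state(s) {}
  virtual ~Extension() = default;

protected:
  // Rewrites every handle entry referring to `op` so it refers to
  // `replacement`. Returns false if no handle referenced `op`.
  bool replacePayloadOp(const std::string &op, const std::string &replacement) {
    bool found = false;
    for (auto &entry : state.handleToOps)
      for (std::string &payload : entry.second)
        if (payload == op) {
          payload = replacement;
          found = true;
        }
    return found;
  }

private:
  State &state;
};

// Re-exposes the protected member publicly by forwarding to it.
class UnsafeReplacementExtension : public Extension {
public:
  using Extension::Extension;
  bool doReplacePayloadOp(const std::string &op, const std::string &replacement) {
    return replacePayloadOp(op, replacement);
  }
};

// Replaces `payload` with `replacement` in all handles; a missing payload is
// deliberately ignored.
void forciblyReplace(State &state, const std::string &payload,
                     const std::string &replacement) {
  UnsafeReplacementExtension extension(state);
  (void)extension.doReplacePayloadOp(payload, replacement);
}
```

A `using Extension::replacePayloadOp;` declaration in the public section of the derived class would expose the member just as well. The forwarding method mirrors what the commit does and keeps the exposed name distinct from the protected one.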

Reviewed By: springerm, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D151555

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/test/Dialect/Linalg/transform-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 52699db910461..f18f24d4c3d9c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -34,6 +34,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/TypeID.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
@@ -663,6 +664,36 @@ bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
   return true;
 }
 
+namespace {
+/// Unsafely exposes an internal protected method of TransformState::Extension
+/// as public.
+///
+/// MUST NOT be used directly.
+class UnsafeOpReplacementStateExtension : public TransformState::Extension {
+public:
+  UnsafeOpReplacementStateExtension(TransformState &state)
+      : TransformState::Extension(state) {}
+
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      UnsafeOpReplacementStateExtension)
+
+  LogicalResult doReplacePayloadOp(Operation *op, Operation *replacement) {
+    return replacePayloadOp(op, replacement);
+  }
+};
+} // namespace
+
+/// Replaces `payload` with `replacement` in all handles stored in the state.
+/// MUST NOT be used except for the case immediately below.
+static void forciblyReplaceReferencedPayloadOperation(TransformState &state,
+                                                      Operation *payload,
+                                                      Operation *replacement) {
+  UnsafeOpReplacementStateExtension extension(state);
+  // This may return failure if the payload is not associated with any handle,
+  // ignore that.
+  (void)extension.doReplacePayloadOp(payload, replacement);
+}
+
 DiagnosedSilenceableFailure
 transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
                                        transform::TransformState &state) {
@@ -757,6 +788,14 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
     return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
   }
 
+  // Update handles associated with the containing op so we don't need to
+  // invalidate them. This is a hack to support better composability between
+  // tiling and fusion while a proper mechanism is being investigated.
+  //
+  // DO NOT replicate this elsewhere unless you understand what you are doing.
+  forciblyReplaceReferencedPayloadOperation(state, *containingOps.begin(),
+                                            containingOp);
+
   results.set(cast<OpResult>(getFusedOp()), fusedOps);
   results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
   return DiagnosedSilenceableFailure::success();
@@ -765,7 +804,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
 void transform::FuseIntoContainingOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   consumesHandle(getProducerOp(), effects);
-  consumesHandle(getContainingOp(), effects);
+  onlyReadsHandle(getContainingOp(), effects);
   producesHandle(getResults(), effects);
   modifiesPayload(effects);
 }

diff  --git a/mlir/test/Dialect/Linalg/transform-ops.mlir b/mlir/test/Dialect/Linalg/transform-ops.mlir
index dd850087ba7f4..8d1c2804609ff 100644
--- a/mlir/test/Dialect/Linalg/transform-ops.mlir
+++ b/mlir/test/Dialect/Linalg/transform-ops.mlir
@@ -35,3 +35,15 @@ transform.sequence failures(propagate) {
   // CHECK: transform.structured.scalarize
   %0 = transform.structured.scalarize %arg0 : (!transform.any_op) -> !transform.any_op
 }
+
+// Check that the second argument of `fuse_into_containing_op` is not consumed
+// (if it had been, we would have seen a diagnostic about multiple consumers).
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
+  %loop = transform.structured.match ops{["scf.forall"]} in %arg0
+    : (!transform.any_op) -> !transform.any_op
+  %0:2 = transform.structured.fuse_into_containing_op %arg1 into %loop
+    : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+  %1:2 = transform.structured.fuse_into_containing_op %arg2 into %loop
+    : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+}
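The new test relies on the verifier rejecting a handle that has more than one potential consumer. The minimal sketch below shows why the second `fuse_into_containing_op` on `%loop` is accepted only once the containing op handle is read-only. It assumes that each transform op simply declares a per-operand effect. `HandleEffect`, `TransformOpUse` and `handlesWithMultipleConsumers` are hypothetical names, not the transform dialect's implementation.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical per-operand effect, loosely mirroring consumesHandle /
// onlyReadsHandle.
enum class HandleEffect { Consume, OnlyRead };

// One transform op in a sequence and the effect it has on each operand handle.
struct TransformOpUse {
  std::string opName;
  std::vector<std::pair<std::string, HandleEffect>> operands;
};

// Returns the handles that have more than one potential consumer.
std::vector<std::string>
handlesWithMultipleConsumers(const std::vector<TransformOpUse> &sequence) {
  std::map<std::string, int> consumerCount;
  for (const TransformOpUse &op : sequence)
    for (const auto &[handle, effect] : op.operands)
      if (effect == HandleEffect::Consume)
        ++consumerCount[handle];

  std::vector<std::string> result;
  for (const auto &[handle, count] : consumerCount)
    if (count > 1)
      result.push_back(handle);
  return result;
}
```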


        

