[Mlir-commits] [mlir] 00d1a1a - [mlir] Add ReplicateOp to the Transform dialect

Alex Zinenko llvmlistbot at llvm.org
Tue Jul 12 02:08:18 PDT 2022


Author: Alex Zinenko
Date: 2022-07-12T09:07:59Z
New Revision: 00d1a1a25fbbfd1fdc43f31194378423a707c738

URL: https://github.com/llvm/llvm-project/commit/00d1a1a25fbbfd1fdc43f31194378423a707c738
DIFF: https://github.com/llvm/llvm-project/commit/00d1a1a25fbbfd1fdc43f31194378423a707c738.diff

LOG: [mlir] Add ReplicateOp to the Transform dialect

This handle manipulation operation allows one to define a new handle that is
associated with the same payload IR operations N times, where N can be driven
by the size of the payload IR operation list associated with another handle.
This can be seen as a sort of broadcast that can be used to ensure the lists
associated with two handles have equal numbers of payload IR ops as expected by
many pairwise transform operations.

Introduce an additional "expensive" check that guards against consuming a
handle that is associated with the same payload IR operation more than once, as
this is likely to lead to a double-free or other undesired effects.

Depends On D129110

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D129216

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/python/mlir/dialects/_transform_ops_ext.py
    mlir/test/Dialect/Transform/expensive-checks.mlir
    mlir/test/Dialect/Transform/test-interpreter.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
    mlir/test/python/dialects/transform.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 5589d937e2ac0..f767361499673 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -845,6 +845,27 @@ transposeResults(const SmallVector<SmallVector<Operation *>, 1> &m) {
   return res;
 }
 } // namespace detail
+
+/// Populates `effects` with the memory effects indicating the operation on the
+/// given handle value:
+///   - consumes = Read + Free,
+///   - produces = Allocate + Write,
+///   - onlyReads = Read.
+void consumesHandle(ValueRange handles,
+                    SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
+void producesHandle(ValueRange handles,
+                    SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
+void onlyReadsHandle(ValueRange handles,
+                     SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
+
+/// Checks whether the transform op consumes the given handle.
+bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);
+
+/// Populates `effects` with the memory effects indicating the access to payload
+/// IR resource.
+void modifiesPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
+void onlyReadsPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
+
 } // namespace transform
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index ee86c888fb7be..8ee8c1ac79ba1 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -174,6 +174,42 @@ def PDLMatchOp : TransformDialectOp<"pdl_match",
   let assemblyFormat = "$pattern_name `in` $root attr-dict";
 }
 
+def ReplicateOp : TransformDialectOp<"replicate",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let summary = "Lists payload ops multiple times in the new handle";
+  let description = [{
+    Produces a new handle associated with a list of payload IR ops that is
+    computed by repeating the list of payload IR ops associated with the
+    operand handle as many times as the "pattern" handle has associated
+    operations. For example, if pattern is associated with [op1, op2] and the
+    operand handle is associated with [op3, op4, op5], the resulting handle
+    will be associated with [op3, op4, op5, op3, op4, op5].
+
+    This transformation is useful to "align" the sizes of payload IR lists
+    before a transformation that expects, e.g., identically-sized lists. For
+    example, a transformation may be parameterized by same notional per-target 
+    size computed at runtime and supplied as another handle, the replication
+    allows this size to be computed only once and used for every target instead
+    of replicating the computation itself.
+
+    Note that it is undesirable to pass a handle with duplicate operations to
+    an operation that consumes the handle. Handle consumption often indicates
+    that the associated payload IR ops are destroyed, so having the same op
+    listed more than once will lead to double-free. Single-operand
+    MergeHandlesOp may be used to deduplicate the associated list of payload IR
+    ops when necessary. Furthermore, a combination of ReplicateOp and
+    MergeHandlesOp can be used to construct arbitrary lists with repetitions.
+  }];
+
+  let arguments = (ins PDL_Operation:$pattern,
+                       Variadic<PDL_Operation>:$handles);
+  let results = (outs Variadic<PDL_Operation>:$replicated);
+  let assemblyFormat =
+    "`num` `(` $pattern `)` $handles "
+    "custom<PDLOpTypedResults>(type($replicated), ref($handles)) attr-dict";
+}
+
 def SequenceOp : TransformDialectOp<"sequence",
     [DeclareOpInterfaceMethods<RegionBranchOpInterface,
         ["getSuccessorEntryOperands", "getSuccessorRegions",

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index ecf1cbe8aa3ed..a5f9693252b67 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -55,7 +55,7 @@ Value transform::TransformState::getHandleForPayloadOp(Operation *op) const {
 LogicalResult transform::TransformState::tryEmplaceReverseMapping(
     Mappings &map, Operation *operation, Value handle) {
   auto insertionResult = map.reverse.insert({operation, handle});
-  if (!insertionResult.second) {
+  if (!insertionResult.second && insertionResult.first->second != handle) {
     InFlightDiagnostic diag = operation->emitError()
                               << "operation tracked by two handles";
     diag.attachNote(handle.getLoc()) << "handle";
@@ -191,9 +191,27 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
 DiagnosedSilenceableFailure
 transform::TransformState::applyTransform(TransformOpInterface transform) {
   LLVM_DEBUG(DBGS() << "applying: " << transform << "\n");
-  if (options.getExpensiveChecksEnabled() &&
-      failed(checkAndRecordHandleInvalidation(transform))) {
-    return DiagnosedSilenceableFailure::definiteFailure();
+  if (options.getExpensiveChecksEnabled()) {
+    if (failed(checkAndRecordHandleInvalidation(transform)))
+      return DiagnosedSilenceableFailure::definiteFailure();
+
+    for (OpOperand &operand : transform->getOpOperands()) {
+      if (!isHandleConsumed(operand.get(), transform))
+        continue;
+
+      DenseSet<Operation *> seen;
+      for (Operation *op : getPayloadOps(operand.get())) {
+        if (!seen.insert(op).second) {
+          DiagnosedSilenceableFailure diag =
+              transform.emitSilenceableError()
+              << "a handle passed as operand #" << operand.getOperandNumber()
+              << " and consumed by this operation points to a payload "
+                 "operation more than once";
+          diag.attachNote(op->getLoc()) << "repeated target op";
+          return diag;
+        }
+      }
+    }
   }
 
   transform::TransformResults results(transform->getNumResults());
@@ -326,6 +344,70 @@ transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Memory effects.
+//===----------------------------------------------------------------------===//
+
+void transform::consumesHandle(
+    ValueRange handles,
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  for (Value handle : handles) {
+    effects.emplace_back(MemoryEffects::Read::get(), handle,
+                         TransformMappingResource::get());
+    effects.emplace_back(MemoryEffects::Free::get(), handle,
+                         TransformMappingResource::get());
+  }
+}
+
+/// Returns `true` if the given list of effects instances contains an instance
+/// with the effect type specified as template parameter.
+template <typename EffectTy>
+static bool hasEffect(ArrayRef<MemoryEffects::EffectInstance> effects) {
+  return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
+    return isa<EffectTy>(effect.getEffect());
+  });
+}
+
+bool transform::isHandleConsumed(Value handle,
+                                 transform::TransformOpInterface transform) {
+  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  iface.getEffectsOnValue(handle, effects);
+  return hasEffect<MemoryEffects::Read>(effects) &&
+         hasEffect<MemoryEffects::Free>(effects);
+}
+
+void transform::producesHandle(
+    ValueRange handles,
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  for (Value handle : handles) {
+    effects.emplace_back(MemoryEffects::Allocate::get(), handle,
+                         TransformMappingResource::get());
+    effects.emplace_back(MemoryEffects::Write::get(), handle,
+                         TransformMappingResource::get());
+  }
+}
+
+void transform::onlyReadsHandle(
+    ValueRange handles,
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  for (Value handle : handles) {
+    effects.emplace_back(MemoryEffects::Read::get(), handle,
+                         TransformMappingResource::get());
+  }
+}
+
+void transform::modifiesPayload(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
+  effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
+}
+
+void transform::onlyReadsPayload(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
+}
+
 //===----------------------------------------------------------------------===//
 // Generated interface implementation.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 1ff58853136d6..42fd968b250d7 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -23,6 +23,16 @@
 
 using namespace mlir;
 
+static ParseResult parsePDLOpTypedResults(
+    OpAsmParser &parser, SmallVectorImpl<Type> &types,
+    const SmallVectorImpl<OpAsmParser::UnresolvedOperand> &handles) {
+  types.resize(handles.size(), pdl::OperationType::get(parser.getContext()));
+  return success();
+}
+
+static void printPDLOpTypedResults(OpAsmPrinter &, Operation *, TypeRange,
+                                   ValueRange) {}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
 
@@ -354,6 +364,33 @@ transform::PDLMatchOp::apply(transform::TransformResults &results,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// ReplicateOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ReplicateOp::apply(transform::TransformResults &results,
+                              transform::TransformState &state) {
+  unsigned numRepetitions = state.getPayloadOps(getPattern()).size();
+  for (const auto &en : llvm::enumerate(getHandles())) {
+    Value handle = en.value();
+    ArrayRef<Operation *> current = state.getPayloadOps(handle);
+    SmallVector<Operation *> payload;
+    payload.reserve(numRepetitions * current.size());
+    for (unsigned i = 0; i < numRepetitions; ++i)
+      llvm::append_range(payload, current);
+    results.set(getReplicated()[en.index()].cast<OpResult>(), payload);
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::ReplicateOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getPattern(), effects);
+  consumesHandle(getHandles(), effects);
+  producesHandle(getReplicated(), effects);
+}
+
 //===----------------------------------------------------------------------===//
 // SequenceOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py
index ca45ab7e28176..e75d6b5f97660 100644
--- a/mlir/python/mlir/dialects/_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_transform_ops_ext.py
@@ -59,6 +59,22 @@ def __init__(self,
         ip=ip)
 
 
+class ReplicateOp:
+
+  def __init__(self,
+               pattern: Union[Operation, Value],
+               handles: Sequence[Union[Operation, Value]],
+               *,
+               loc=None,
+               ip=None):
+    super().__init__(
+        [pdl.OperationType.get()] * len(handles),
+        _get_op_result_or_value(pattern),
+        [_get_op_result_or_value(h) for h in handles],
+        loc=loc,
+        ip=ip)
+
+
 class SequenceOp:
 
   @overload

diff  --git a/mlir/test/Dialect/Transform/expensive-checks.mlir b/mlir/test/Dialect/Transform/expensive-checks.mlir
index c86367154f37b..2e49efc2ea50a 100644
--- a/mlir/test/Dialect/Transform/expensive-checks.mlir
+++ b/mlir/test/Dialect/Transform/expensive-checks.mlir
@@ -25,3 +25,37 @@ transform.with_pdl_patterns {
     test_print_remark_at_operand %0, "remark"
   }
 }
+
+// -----
+
+func.func @func1() {
+  // expected-note @below {{repeated target op}}
+  return
+}
+func.func private @func2()
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @func : benefit(1) {
+    %0 = operands
+    %1 = types
+    %2 = operation "func.func"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    rewrite %2 with "transform.dialect"
+  }
+  pdl.pattern @return : benefit(1) {
+    %0 = operands
+    %1 = types
+    %2 = operation "func.return"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    rewrite %2 with "transform.dialect"
+  }
+
+  sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @func in %arg1
+    %1 = pdl_match @return in %arg1
+    %2 = replicate num(%0) %1
+    // expected-error @below {{a handle passed as operand #0 and consumed by this operation points to a payload operation more than once}}
+    test_consume_operand %2
+    test_print_remark_at_operand %0, "remark"
+  }
+}

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index a47021396e1a6..1f0dd40fc8b8b 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -569,3 +569,31 @@ transform.with_pdl_patterns {
     transform.test_mixed_sucess_and_silenceable %0
   }
 }
+
+// -----
+
+module {
+  func.func private @foo()
+  func.func private @bar()
+
+  transform.with_pdl_patterns {
+  ^bb0(%arg0: !pdl.operation):
+    pdl.pattern @func : benefit(1) {
+      %0 = pdl.operands
+      %1 = pdl.types
+      %2 = pdl.operation "func.func"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+      pdl.rewrite %2 with "transform.dialect"
+    }
+
+    transform.sequence %arg0 {
+    ^bb0(%arg1: !pdl.operation):
+      %0 = pdl_match @func in %arg1
+      %1 = replicate num(%0) %arg1
+      // expected-remark @below {{2}}
+      test_print_number_of_associated_payload_ir_ops %1
+      %2 = replicate num(%0) %1
+      // expected-remark @below {{4}}
+      test_print_number_of_associated_payload_ir_ops %2
+    }
+  }
+}

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 7930f3904b78b..5b7a0b88752ee 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -275,6 +275,18 @@ mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne(
   return emitDefaultSilenceableFailure(target);
 }
 
+DiagnosedSilenceableFailure
+mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  emitRemark() << state.getPayloadOps(getHandle()).size();
+  return DiagnosedSilenceableFailure::success();
+}
+
+void mlir::test::TestPrintNumberOfAssociatedPayloadIROps::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getHandle(), effects);
+}
+
 namespace {
 /// Test extension of the Transform dialect. Registers additional ops and
 /// declares PDL as dependent dialect since the additional ops are using PDL

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 4c144e7be9455..83b5700ff0741 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -212,4 +212,13 @@ def TestMixedSuccessAndSilenceableOp
   }];
 }
 
+def TestPrintNumberOfAssociatedPayloadIROps
+  : Op<Transform_Dialect, "test_print_number_of_associated_payload_ir_ops",
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let arguments = (ins PDL_Operation:$handle);
+  let assemblyFormat = "$handle attr-dict";
+  let cppNamespace = "::mlir::test";
+}
+
 #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD

diff  --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 21392ca7d5e32..bb3afb15ffb68 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -94,3 +94,19 @@ def testMergeHandlesOp():
   # CHECK: transform.sequence
   # CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
   # CHECK:   = merge_handles %[[ARG1]]
+
+
+@run
+def testReplicateOp():
+  with_pdl = transform.WithPDLPatternsOp()
+  with InsertionPoint(with_pdl.body):
+    sequence = transform.SequenceOp(with_pdl.bodyTarget)
+    with InsertionPoint(sequence.body):
+      m1 = transform.PDLMatchOp(sequence.bodyTarget, "first")
+      m2 = transform.PDLMatchOp(sequence.bodyTarget, "second")
+      transform.ReplicateOp(m1, [m2])
+      transform.YieldOp()
+  # CHECK-LABEL: TEST: testReplicateOp
+  # CHECK: %[[FIRST:.+]] = pdl_match
+  # CHECK: %[[SECOND:.+]] = pdl_match
+  # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]


        


More information about the Mlir-commits mailing list