[Mlir-commits] [mlir] 288529e - [mlir][transform] Clean up SplitHandlesOp

Matthias Springer llvmlistbot at llvm.org
Fri May 5 06:30:09 PDT 2023


Author: Matthias Springer
Date: 2023-05-05T22:29:43+09:00
New Revision: 288529e730f8f225c96ec2f816cef94593bb61a8

URL: https://github.com/llvm/llvm-project/commit/288529e730f8f225c96ec2f816cef94593bb61a8
DIFF: https://github.com/llvm/llvm-project/commit/288529e730f8f225c96ec2f816cef94593bb61a8.diff

LOG: [mlir][transform] Clean up SplitHandlesOp

* Rename to `SplitHandleOp`: it splits a single handle.
* Drop `num_result_handles` attribute: it is redundant and can be inferred from the number of results.
* Improve documentation and minor code cleanups.

Differential Revision: https://reviews.llvm.org/D149937

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/test-interpreter.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 4bee1d4c6eb2e..53c1c0af54c0f 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -509,26 +509,28 @@ def NamedSequenceOp : TransformDialectOp<"named_sequence",
   }];
 }
 
-def SplitHandlesOp : TransformDialectOp<"split_handles",
+def SplitHandleOp : TransformDialectOp<"split_handle",
     [FunctionalStyleTransformOpTrait,
      DeclareOpInterfaceMethods<TransformOpInterface>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
-  let summary = "Splits handles from a union of payload ops to a list";
+  let summary = "Splits a handle of payload ops into handles with a single op";
   let description = [{
-    Creates `num_result_handles` transform IR handles extracted from the
-    `handle` operand. The resulting Payload IR operation handles are listed
-    in the same order as the operations appear in the source `handle`.
-    This is useful for ensuring a statically known number of operations are
-    tracked by the source `handle` and to extract them into individual handles
-    that can be further manipulated in isolation.
-
-    This operation succeeds and returns `num_result_handles` if the statically
-    specified `num_result_handles` corresponds to the dynamic number of
-    operations contained in the source `handle`. Otherwise it silently fails.
+    Splits `handle` into one or multiple handles, as specified by the number
+    of results of this operation. `handle` should be mapped to as many payload
+    ops as there are results. Otherwise, a silenceable failure is produced.
+    Each result handle is mapped to exactly one payload op. The order
+    of the payload ops is preserved, i.e., the i-th payload op is mapped to the
+    i-th result handle.
+
+    This operation is useful for ensuring a statically known number of
+    operations are tracked by the source `handle` and to extract them into
+    individual handles that can be further manipulated in isolation.
+
+    If `handle` is empty, this transform will succeed and all result handles
+    are empty.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$handle,
-                       I64Attr:$num_result_handles);
+  let arguments = (ins TransformHandleTypeInterface:$handle);
   let results = (outs Variadic<TransformHandleTypeInterface>:$results);
 
   let builders = [
@@ -536,8 +538,7 @@ def SplitHandlesOp : TransformDialectOp<"split_handles",
   ];
 
   let assemblyFormat = [{
-    $handle `in` `[` $num_result_handles `]`
-    attr-dict `:` functional-type(operands, results)
+    $handle attr-dict `:` functional-type(operands, results)
   }];
 }
 

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 138f4ff69cfbc..5c39ccc3ea510 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1488,48 +1488,43 @@ LogicalResult transform::NamedSequenceOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
-// SplitHandlesOp
+// SplitHandleOp
 //===----------------------------------------------------------------------===//
 
-void transform::SplitHandlesOp::build(OpBuilder &builder,
-                                      OperationState &result, Value target,
-                                      int64_t numResultHandles) {
+void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
+                                     Value target, int64_t numResultHandles) {
   result.addOperands(target);
-  result.addAttribute(SplitHandlesOp::getNumResultHandlesAttrName(result.name),
-                      builder.getI64IntegerAttr(numResultHandles));
   auto pdlOpType = pdl::OperationType::get(builder.getContext());
   result.addTypes(SmallVector<pdl::OperationType>(numResultHandles, pdlOpType));
 }
 
 DiagnosedSilenceableFailure
-transform::SplitHandlesOp::apply(transform::TransformResults &results,
-                                 transform::TransformState &state) {
-  int64_t numResultHandles =
-      getHandle() ? state.getPayloadOps(getHandle()).size() : 0;
-  int64_t expectedNumResultHandles = getNumResultHandles();
-  if (numResultHandles != expectedNumResultHandles) {
-    // Empty input handle corner case: always propagates empty handles in both
-    // suppress and propagate modes.
-    if (numResultHandles == 0) {
-      for (OpResult result : getResults())
-        results.set(result, {});
-      return DiagnosedSilenceableFailure::success();
-    }
+transform::SplitHandleOp::apply(transform::TransformResults &results,
+                                transform::TransformState &state) {
+  int64_t numPayloadOps = state.getPayloadOps(getHandle()).size();
 
-    // If the input handle was not empty and the number of result handles does
-    // not match, this is a legit silenceable error.
-    return emitSilenceableError()
-           << getHandle() << " expected to contain " << expectedNumResultHandles
-           << " operation handles but it contains " << numResultHandles
-           << " handles";
+  // Empty handle corner case: all result handles are empty.
+  if (numPayloadOps == 0) {
+    for (OpResult result : getResults())
+      results.set(result, {});
+    return DiagnosedSilenceableFailure::success();
   }
-  // Normal successful case.
+
+  // If the input handle was not empty and the number of payload ops does not
+  // match, this is a legit silenceable error.
+  if (numPayloadOps != getNumResults())
+    return emitSilenceableError()
+           << getHandle() << " expected to contain " << getNumResults()
+           << " payload ops but it contains " << numPayloadOps
+           << " payload ops";
+
   for (const auto &en : llvm::enumerate(state.getPayloadOps(getHandle())))
     results.set(getResults()[en.index()].cast<OpResult>(), en.value());
+
   return DiagnosedSilenceableFailure::success();
 }
 
-void transform::SplitHandlesOp::getEffects(
+void transform::SplitHandleOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   onlyReadsHandle(getHandle(), effects);
   producesHandle(getResults(), effects);

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 3668abb1c3a42..9ed99577fbd0b 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -818,7 +818,7 @@ transform.sequence failures(propagate) {
 
 // -----
 
-func.func @split_handles(%a: index, %b: index, %c: index) {
+func.func @split_handle(%a: index, %b: index, %c: index) {
   %0 = arith.muli %a, %b : index
   %1 = arith.muli %a, %c : index
   return
@@ -827,17 +827,17 @@ func.func @split_handles(%a: index, %b: index, %c: index) {
 transform.sequence failures(propagate) {
 ^bb1(%fun: !pdl.operation):
   %muli = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
-  %h:2 = split_handles %muli in [2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+  %h:2 = split_handle %muli : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
   // expected-remark @below {{1}}
   transform.test_print_number_of_associated_payload_ir_ops %h#0
   %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
-  // expected-error @below {{expected to contain 3 operation handles but it contains 2 handles}}
-  %h_2:3 = split_handles %muli_2 in [3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
+  // expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
+  %h_2:3 = split_handle %muli_2 : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
 }
 
 // -----
 
-func.func @split_handles(%a: index, %b: index, %c: index) {
+func.func @split_handle(%a: index, %b: index, %c: index) {
   %0 = arith.muli %a, %b : index
   %1 = arith.muli %a, %c : index
   return
@@ -846,12 +846,12 @@ func.func @split_handles(%a: index, %b: index, %c: index) {
 transform.sequence failures(suppress) {
 ^bb1(%fun: !pdl.operation):
   %muli = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
-  %h:2 = split_handles %muli in [2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+  %h:2 = split_handle %muli : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
   // expected-remark @below {{1}}
   transform.test_print_number_of_associated_payload_ir_ops %h#0
   %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
   // Silenceable failure and all handles are now empty.
-  %h_2:3 = split_handles %muli_2 in [3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
+  %h_2:3 = split_handle %muli_2 : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
   // expected-remark @below {{0}}
   transform.test_print_number_of_associated_payload_ir_ops %h_2#0
 }
@@ -966,7 +966,7 @@ transform.with_pdl_patterns {
 
 // -----
 
-func.func @split_handles(%a: index, %b: index, %c: index) {
+func.func @split_handle(%a: index, %b: index, %c: index) {
   %0 = arith.muli %a, %b : index
   %1 = arith.muli %a, %c : index
   return
@@ -975,8 +975,8 @@ func.func @split_handles(%a: index, %b: index, %c: index) {
 transform.sequence -> !pdl.operation failures(propagate) {
 ^bb1(%fun: !pdl.operation):
   %muli = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
-  // expected-error @below {{expected to contain 3 operation handles but it contains 2 handles}}
-  %h_2:3 = split_handles %muli in [3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
+  // expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
+  %h_2:3 = split_handle %muli : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
   /// Test that yield does not crash in the presence of silenceable error in
   /// propagate mode.
   yield %fun : !pdl.operation
@@ -988,7 +988,7 @@ transform.sequence -> !transform.any_op failures(suppress) {
 ^bb0(%arg0: !transform.any_op):
   %muli = transform.structured.match ops{["arith.muli"]} in %arg0 : (!transform.any_op) -> !transform.any_op
   // Edge case propagating empty handles in splitting.
-  %0:3 = split_handles %muli in [3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+  %0:3 = split_handle %muli : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
   // Test does not crash when accessing the empty handle.
   yield %0#0 : !transform.any_op
 }


        


More information about the Mlir-commits mailing list