[Mlir-commits] [mlir] make transform.split_handle accept any handle kind (PR #118752)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Wed Dec 4 23:19:51 PST 2024


https://github.com/ftynse created https://github.com/llvm/llvm-project/pull/118752

It can now split value and parameter handles in addition to operation handles. This is generally useful functionality.

>From ae91a4c8d82c8ef3f625df53ebc77e1efb0f8406 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <git at ozinenko.com>
Date: Wed, 4 Dec 2024 16:24:36 +0100
Subject: [PATCH] make transform.split_handle accept any handle kind

It can now split value and parameter handles in addition to operation handles.
This is generally useful functionality.
---
 .../mlir/Dialect/Transform/IR/TransformOps.td | 21 +++---
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 64 +++++++++++++----
 .../Dialect/Transform/test-interpreter.mlir   | 69 ++++++++++++++++++-
 3 files changed, 130 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index b946fc8875860b..2d71d5b0892afe 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -1062,36 +1062,37 @@ def SplitHandleOp : TransformDialectOp<"split_handle",
     [FunctionalStyleTransformOpTrait,
      DeclareOpInterfaceMethods<TransformOpInterface>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
-  let summary = "Splits a handle of payload ops into handles with a single op";
+  let summary = "Splits a handle or parameter into multiple values";
   let description = [{
     Splits `handle` into one or multiple handles, as specified by the number
     of results of this operation. `handle` should be mapped to as many payload
-    ops as there are results. Otherwise, this transform will fail produces a
-    silenceable failure by default. Each result handle is mapped to exactly one
-    payload op. The order of the payload ops is preserved, i.e., the i-th
-    payload op is mapped to the i-th result handle.
+    ops, values or parameters as there are results. Otherwise, this transform
+    produces a silenceable failure by default. Each result handle is mapped to
+    exactly one payload unless specified otherwise by attributes described
+    below. The order of the payloads is preserved, i.e., the i-th payload is
+    mapped to the i-th result handle.
 
     This operation is useful for ensuring a statically known number of
-    operations are tracked by the source `handle` and to extract them into
+    payloads are tracked by the source `handle` and to extract them into
     individual handles that can be further manipulated in isolation.
 
-    If there are more payload ops than results, the remaining ops are mapped to
-    the result with index `overflow_result`. If no `overflow_result` is
+    If there are more payloads than results, the remaining payloads are mapped
+    to the result with index `overflow_result`. If no `overflow_result` is
     specified, the transform produces a silenceable failure.
 
-    If there are fewer payload ops than results, the transform produces a
+    If there are fewer payloads than results, the transform produces a
     silenceable failure if `fail_on_payload_too_small` is set to "true".
     Otherwise, it succeeds and the remaining result handles are not mapped to
-    any op. It also succeeds if `handle` is empty and
+    anything. It also succeeds if `handle` is empty and
     `pass_through_empty_handle` is set to "true", regardless of
     `fail_on_payload_too_small`.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$handle,
+  let arguments = (ins Transform_AnyHandleOrParamType:$handle,
                        DefaultValuedAttr<BoolAttr, "true">:$pass_through_empty_handle,
                        DefaultValuedAttr<BoolAttr, "true">:$fail_on_payload_too_small,
                        OptionalAttr<I64Attr>:$overflow_result);
-  let results = (outs Variadic<TransformHandleTypeInterface>:$results);
+  let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
   let hasVerifier = 1;
 
   let builders = [
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 590cae9aa0d667..68d1f2aef638a5 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -2415,32 +2415,50 @@ DiagnosedSilenceableFailure
 transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
                                 transform::TransformResults &results,
                                 transform::TransformState &state) {
-  int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle()));
+  // Collect the payloads associated with the operand, whatever their kind, as
+  // MappedValues so that they can be counted and distributed uniformly.
+  SmallVector<MappedValue> payloads = [&]() {
+    if (isa<TransformHandleTypeInterface>(getHandle().getType())) {
+      return llvm::map_to_vector(
+          state.getPayloadOps(getHandle()),
+          [](Operation *op) -> MappedValue { return op; });
+    }
+    if (isa<TransformValueHandleTypeInterface>(getHandle().getType())) {
+      return llvm::map_to_vector(state.getPayloadValues(getHandle()),
+                                 [](Value v) -> MappedValue { return v; });
+    }
+    assert(isa<TransformParamTypeInterface>(getHandle().getType()) &&
+           "unsupported kind of transform dialect type");
+    return llvm::map_to_vector(state.getParams(getHandle()),
+                               [](Attribute a) -> MappedValue { return a; });
+  }();
+  int64_t numPayloads = payloads.size();
+
   auto produceNumOpsError = [&]() {
     return emitSilenceableError()
            << getHandle() << " expected to contain " << this->getNumResults()
-           << " payload ops but it contains " << numPayloadOps
-           << " payload ops";
+           << " payloads but it contains " << numPayloads << " payloads";
   };
 
-  // Fail if there are more payload ops than results and no overflow result was
+  // Fail if there are more payloads than results and no overflow result was
   // specified.
-  if (numPayloadOps > getNumResults() && !getOverflowResult().has_value())
+  if (numPayloads > getNumResults() && !getOverflowResult().has_value())
     return produceNumOpsError();
 
-  // Fail if there are more results than payload ops. Unless:
+  // Fail if there are more results than payloads. Unless:
   // - "fail_on_payload_too_small" is set to "false", or
-  // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
-  if (numPayloadOps < getNumResults() && getFailOnPayloadTooSmall() &&
-      (numPayloadOps != 0 || !getPassThroughEmptyHandle()))
+  // - "pass_through_empty_handle" is set to "true" and there are 0 payloads.
+  if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
+      (numPayloads != 0 || !getPassThroughEmptyHandle()))
     return produceNumOpsError();
 
-  // Distribute payload ops.
-  SmallVector<SmallVector<Operation *, 1>> resultHandles(getNumResults(), {});
-  if (getOverflowResult())
-    resultHandles[*getOverflowResult()].reserve(numPayloadOps -
-                                                getNumResults());
-  for (auto &&en : llvm::enumerate(state.getPayloadOps(getHandle()))) {
+  // Distribute payloads. Only reserve room for overflow payloads when there
+  // are more payloads than results: with fewer (e.g., an empty handle passed
+  // through), the difference is negative and would wrap around in reserve().
+  SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {});
+  if (getOverflowResult() && numPayloads > getNumResults())
+    resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
+  for (auto &&en : llvm::enumerate(payloads)) {
     int64_t resultNum = en.index();
     if (resultNum >= getNumResults())
       resultNum = *getOverflowResult();
@@ -2449,7 +2467,8 @@ transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
 
   // Set transform op results.
   for (auto &&it : llvm::enumerate(resultHandles))
-    results.set(llvm::cast<OpResult>(getResult(it.index())), it.value());
+    results.setMappedValues(llvm::cast<OpResult>(getResult(it.index())),
+                            it.value());
 
   return DiagnosedSilenceableFailure::success();
 }
@@ -2466,6 +2485,15 @@ LogicalResult transform::SplitHandleOp::verify() {
   if (getOverflowResult().has_value() &&
       !(*getOverflowResult() < getNumResults()))
     return emitOpError("overflow_result is not a valid result index");
+
+  for (Type resultType : getResultTypes()) {
+    if (implementSameTransformInterface(getHandle().getType(), resultType))
+      continue;
+
+    return emitOpError("expects result types to implement the same transform "
+                       "interface as the operand type");
+  }
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 4fe2dbedff56e3..ecc234587cda95 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1094,7 +1094,7 @@ module attributes {transform.with_named_sequence} {
     // expected-remark @below {{1}}
     transform.debug.emit_param_as_remark  %p : !transform.param<i64>
     %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
-    // expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
+    // expected-error @below {{expected to contain 3 payloads but it contains 2 payloads}}
     %h_2:3 = transform.split_handle %muli_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
@@ -1180,6 +1180,71 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func private @opaque() -> (i32, i32)
+
+func.func @split_handle() {
+  func.call @opaque() : () -> (i32, i32)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%fun: !transform.any_op) {
+    %op = transform.structured.match ops{["func.call"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %val = transform.get_result %op[all] : (!transform.any_op) -> !transform.any_value
+    %p = transform.num_associations %val : (!transform.any_value) -> !transform.any_param
+    // expected-remark @below {{total 2}}
+    transform.debug.emit_param_as_remark %p, "total" : !transform.any_param
+    %h:2 = transform.split_handle %val : (!transform.any_value) -> (!transform.any_value, !transform.any_value)
+    %p1 = transform.num_associations %h#0 : (!transform.any_value) -> !transform.any_param
+    %p2 = transform.num_associations %h#1 : (!transform.any_value) -> !transform.any_param
+    // expected-remark @below {{first 1}}
+    transform.debug.emit_param_as_remark %p1, "first" : !transform.any_param
+    // expected-remark @below {{second 1}}
+    transform.debug.emit_param_as_remark %p2, "second" : !transform.any_param
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @opaque() -> (i32, i32)
+
+func.func @split_handle() {
+  func.call @opaque() : () -> (i32, i32)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%fun: !transform.any_op) {
+    %op = transform.structured.match ops{["func.call"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %val = transform.get_result %op[all] : (!transform.any_op) -> !transform.any_value
+    %type = transform.get_type %val : (!transform.any_value) -> !transform.any_param
+    %p = transform.num_associations %type : (!transform.any_param) -> !transform.any_param
+    // expected-remark @below {{total 2}}
+    transform.debug.emit_param_as_remark %p, "total" : !transform.any_param
+    %h:2 = transform.split_handle %type : (!transform.any_param) -> (!transform.any_param, !transform.any_param)
+    %p1 = transform.num_associations %h#0 : (!transform.any_param) -> !transform.any_param
+    %p2 = transform.num_associations %h#1 : (!transform.any_param) -> !transform.any_param
+    // expected-remark @below {{first 1}}
+    transform.debug.emit_param_as_remark %p1, "first" : !transform.any_param
+    // expected-remark @below {{second 1}}
+    transform.debug.emit_param_as_remark %p2, "second" : !transform.any_param
+    transform.yield
+  }
+}
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%fun: !transform.any_op) {
+    // expected-error @below {{op expects result types to implement the same transform interface as the operand type}}
+    transform.split_handle %fun : (!transform.any_op) -> (!transform.any_op, !transform.any_value)
+    transform.yield
+  }
+}
+
+// -----
+
 "test.some_op"() : () -> ()
 "other_dialect.other_op"() : () -> ()
 
@@ -1324,7 +1389,7 @@ module attributes {transform.with_named_sequence} {
     transform.sequence %root : !transform.any_op -> !transform.any_op failures(propagate) {
     ^bb1(%fun: !transform.any_op):
       %muli = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
-      // expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
+      // expected-error @below {{expected to contain 3 payloads but it contains 2 payloads}}
       %h_2:3 = split_handle %muli : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
       /// Test that yield does not crash in the presence of silenceable error in
       /// propagate mode.



More information about the Mlir-commits mailing list