[Mlir-commits] [mlir] [mlir][Transforms] Add a utility method to move value definitions. (PR #130874)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 11 18:32:00 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: None (MaheshRavishankar)

<details>
<summary>Changes</summary>

https://github.com/llvm/llvm-project/commit/205c5325b3c771d94feb0ec07e8ad89d27c2b29e added a transform utility that moves all SSA dependences of an operation before an insertion point. Similarly, this PR adds a transform utility function, `moveValueDefinitions`, that moves the slice of operations defining all values in a `ValueRange` before the insertion point. While very similar to `moveOperationDependencies`, this method differs in a few ways:

1. Since the backward slice starts from a value rather than from an operation, the computed slice needs to be inclusive, so that the value's defining operation is itself moved.
2. The combined backward slice needs to be sorted topologically before its operations are moved, to avoid SSA use-def violations while moving individual ops.

The PR also adds a new transform op to test this new utility function.

---
Full diff: https://github.com/llvm/llvm-project/pull/130874.diff


5 Files Affected:

- (modified) mlir/include/mlir/Transforms/RegionUtils.h (+11) 
- (modified) mlir/lib/Transforms/Utils/RegionUtils.cpp (+70-1) 
- (modified) mlir/test/Transforms/move-operation-deps.mlir (+226) 
- (modified) mlir/test/lib/Transforms/TestTransformsOps.cpp (+18) 
- (modified) mlir/test/lib/Transforms/TestTransformsOps.td (+22) 


``````````diff
diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
index e6b928d8ebecc..2ed96afbace81 100644
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -11,6 +11,7 @@
 
 #include "mlir/IR/Region.h"
 #include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
 
 #include "llvm/ADT/SetVector.h"
 
@@ -80,6 +81,16 @@ LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
 LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
                                         Operation *insertionPoint);
 
+/// Move the defining ops of `values`, plus the ops they transitively depend
+/// on, before `insertionPoint`. Values already dominating it are skipped; the
+/// others must be op results in the insertion point's block. All-or-nothing:
+/// either all needed definitions are moved, or none are and failure returns.
+LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
+                                   Operation *insertionPoint,
+                                   DominanceInfo &dominance);
+LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
+                                   Operation *insertionPoint);
+
 /// Run a set of structural simplifications over the given regions. This
 /// includes transformations like unreachable block elimination, dead argument
 /// elimination, as well as some other DCE. This function returns success if any
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index da0d486f0fdcb..6987a13b309d7 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -1070,7 +1070,7 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
   // in different basic blocks.
   if (op->getBlock() != insertionPoint->getBlock()) {
     return rewriter.notifyMatchFailure(
-        op, "unsupported caes where operation and insertion point are not in "
+        op, "unsupported case where operation and insertion point are not in "
             "the same basic block");
   }
   // If `insertionPoint` does not dominate `op`, do nothing
@@ -1115,3 +1115,72 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
   DominanceInfo dominance(op);
   return moveOperationDependencies(rewriter, op, insertionPoint, dominance);
 }
+
+LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
+                                         ValueRange values,
+                                         Operation *insertionPoint,
+                                         DominanceInfo &dominance) {
+  // Drop values that already dominate the insertion point; check the rest.
+  SmallVector<Value> prunedValues;
+  for (Value value : values) {
+    if (dominance.properlyDominates(value, insertionPoint)) {
+      continue;
+    }
+    // Block arguments are not supported.
+    if (isa<BlockArgument>(value)) {
+      return rewriter.notifyMatchFailure(
+          insertionPoint,
+          "unsupported case of moving block argument before insertion point");
+    }
+    // Moving a definition into a different basic block is not supported yet.
+    // (`getDefiningOp()` is non-null here: block arguments were rejected.)
+    if (value.getDefiningOp()->getBlock() != insertionPoint->getBlock()) {
+      return rewriter.notifyMatchFailure(
+          insertionPoint,
+          "unsupported case of moving definition of value before an insertion "
+          "point in a different basic block");
+    }
+    prunedValues.push_back(value);
+  }
+
+  // Collect the inclusive backward slice of every remaining value, stopping
+  // at operations that already properly dominate `insertionPoint` (those do
+  // not need to move).
+  BackwardSliceOptions options;
+  options.inclusive = true;
+  options.omitUsesFromAbove = false;
+  // Since current support is to only move within the same basic block,
+  // the slices don't need to look past block arguments.
+  options.omitBlockArguments = true;
+  options.filter = [&](Operation *sliceBoundaryOp) {
+    return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
+  };
+  llvm::SetVector<Operation *> slice;
+  for (Value value : prunedValues) {
+    getBackwardSlice(value, &slice, options);
+  }
+
+  // If the slice contains `insertionPoint`, the definitions cannot be moved.
+  if (slice.contains(insertionPoint)) {
+    return rewriter.notifyMatchFailure(
+        insertionPoint,
+        "cannot move dependencies before operation in backward slice of op");
+  }
+
+  // `topologicalSort` returns a sorted copy; assign it back to use the order.
+  slice = mlir::topologicalSort(slice);
+
+  // Moving ops before `insertionPoint` in topological order keeps every
+  // definition ahead of its uses, so no SSA use-def violation is introduced.
+  for (Operation *op : slice) {
+    rewriter.moveOpBefore(op, insertionPoint);
+  }
+  return success();
+}
+
+LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
+                                         ValueRange values,
+                                         Operation *insertionPoint) {
+  DominanceInfo domInfo(insertionPoint); // Fresh, call-local analysis.
+  return mlir::moveValueDefinitions(rewriter, values, insertionPoint, domInfo);
+}
diff --git a/mlir/test/Transforms/move-operation-deps.mlir b/mlir/test/Transforms/move-operation-deps.mlir
index 37637152938f6..aa7b5dc2a240a 100644
--- a/mlir/test/Transforms/move-operation-deps.mlir
+++ b/mlir/test/Transforms/move-operation-deps.mlir
@@ -234,3 +234,229 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// Check that definitions of two independent values move before the insertion point.
+func.func @simple_move_values() -> f32 {
+  %0 = "before"() : () -> (f32)
+  %1 = "moved_op_1"() : () -> (f32)
+  %2 = "moved_op_2"() : () -> (f32)
+  %3 = "foo"(%1, %2) : (f32, f32) -> (f32)
+  return %3 : f32
+}
+// CHECK-LABEL: func @simple_move_values()
+//       CHECK:   %[[MOVED1:.+]] = "moved_op_1"
+//       CHECK:   %[[MOVED2:.+]] = "moved_op_2"
+//       CHECK:   %[[BEFORE:.+]] = "before"
+//       CHECK:   %[[FOO:.+]] = "foo"(%[[MOVED1]], %[[MOVED2]])
+//       CHECK:   return %[[FOO]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["moved_op_1"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["moved_op_2"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op3 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
+    %v2 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
+    transform.test.move_value_defns %v1, %v2 before %op3
+        : (!transform.any_value, !transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Moving a region op also moves the defs of values its region uses from above.
+func.func @move_region_dependencies_values() -> f32 {
+  %0 = "before"() : () -> (f32)
+  %1 = "moved_op_1"() : () -> (f32)
+  %2 = "moved_op_2"() ({
+    %3 = "inner_op"(%1) : (f32) -> (f32)
+    "yield"(%3) : (f32) -> ()
+  }) : () -> (f32)
+  return %2 : f32
+}
+// CHECK-LABEL: func @move_region_dependencies_values()
+//       CHECK:   %[[MOVED1:.+]] = "moved_op_1"
+//       CHECK:   %[[MOVED2:.+]] = "moved_op_2"
+//       CHECK:   %[[BEFORE:.+]] = "before"
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["moved_op_2"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
+    transform.test.move_value_defns %v1 before %op2
+        : (!transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Move the combined slices of several values in topological order.
+func.func @move_values_in_topological_sort_order() -> f32 {
+  %0 = "before"() : () -> (f32)
+  %1 = "moved_op_1"() : () -> (f32)
+  %2 = "moved_op_2"() : () -> (f32)
+  %3 = "moved_op_3"(%1) : (f32) -> (f32)
+  %4 = "moved_op_4"(%1, %3) : (f32, f32) -> (f32)
+  %5 = "moved_op_5"(%2) : (f32) -> (f32)
+  %6 = "foo"(%4, %5) : (f32, f32) -> (f32)
+  return %6 : f32
+}
+// CHECK-LABEL: func @move_values_in_topological_sort_order()
+//       CHECK:   %[[MOVED_1:.+]] = "moved_op_1"
+//   CHECK-DAG:   %[[MOVED_2:.+]] = "moved_op_3"(%[[MOVED_1]])
+//   CHECK-DAG:   %[[MOVED_3:.+]] = "moved_op_4"(%[[MOVED_1]], %[[MOVED_2]])
+//   CHECK-DAG:   %[[MOVED_4:.+]] = "moved_op_2"
+//   CHECK-DAG:   %[[MOVED_5:.+]] = "moved_op_5"(%[[MOVED_4]])
+//       CHECK:   %[[BEFORE:.+]] = "before"
+//       CHECK:   %[[FOO:.+]] = "foo"(%[[MOVED_3]], %[[MOVED_5]])
+//       CHECK:   return %[[FOO]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["moved_op_4"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["moved_op_5"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op3 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
+    %v2 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
+    transform.test.move_value_defns %v1, %v2 before %op3
+        : (!transform.any_value, !transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Move only the definitions that do not already dominate the insertion point
+
+func.func @move_only_required_defns() -> (f32, f32, f32, f32) {
+  %0 = "unmoved_op"() : () -> (f32)
+  %1 = "dummy_op"() : () -> (f32)
+  %2 = "before"() : () -> (f32)
+  %3 = "moved_op"() : () -> (f32)
+  return %0, %1, %2, %3 : f32, f32, f32, f32
+}
+// CHECK-LABEL: func @move_only_required_defns()
+//       CHECK:   %[[UNMOVED:.+]] = "unmoved_op"
+//       CHECK:   %[[DUMMY:.+]] = "dummy_op"
+//       CHECK:   %[[MOVED:.+]] = "moved_op"
+//       CHECK:   %[[BEFORE:.+]] = "before"
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["unmoved_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["dummy_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op3 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op4 = transform.structured.match ops{["moved_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
+    %v2 = transform.get_result %op4[0] : (!transform.any_op) -> !transform.any_value
+    transform.test.move_value_defns %v1, %v2 before %op3
+        : (!transform.any_value, !transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Values given in reverse topological order (a consumer before its producer)
+// are still moved in an order that keeps every definition before its uses.
+
+func.func @move_values_given_in_reverse_order() -> (f32, f32) {
+  %0 = "before"() : () -> (f32)
+  %1 = "moved_op_1"() : () -> (f32)
+  %2 = "moved_op_2"(%1) : (f32) -> (f32)
+  %3 = "moved_op_3"(%1, %2) : (f32, f32) -> (f32)
+  return %0, %3 : f32, f32
+}
+// CHECK-LABEL: func @move_values_given_in_reverse_order()
+//       CHECK:   %[[MOVED1:.+]] = "moved_op_1"
+//       CHECK:   %[[MOVED2:.+]] = "moved_op_2"(%[[MOVED1]])
+//       CHECK:   %[[MOVED3:.+]] = "moved_op_3"(%[[MOVED1]], %[[MOVED2]])
+//       CHECK:   %[[BEFORE:.+]] = "before"
+//       CHECK:   return %[[BEFORE]], %[[MOVED3]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["moved_op_3"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["moved_op_2"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op3 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
+    %v2 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
+    transform.test.move_value_defns %v1, %v2 before %op3
+        : (!transform.any_value, !transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Block-argument operands are not traversed; the op using them is still moved.
+func.func @move_defn_using_block_argument() -> (f32, f32) {
+  %0 = "unmoved_op"() : () -> (f32)
+  cf.br ^bb0(%0 : f32)
+ ^bb0(%arg0 : f32) :
+  %1 = "before"() : () -> (f32)
+  %2 = "moved_op"(%arg0) : (f32) -> (f32)
+  return %1, %2 : f32, f32
+}
+// CHECK-LABEL: func @move_defn_using_block_argument()
+//       CHECK:   %[[MOVED:.+]] = "moved_op"
+//       CHECK:   %[[BEFORE:.+]] = "before"
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["moved_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
+    transform.test.move_value_defns %v1 before %op1
+        : (!transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// A definition in a different block than the insertion point is not moved
+func.func @no_move_across_basic_blocks() -> (f32, f32) {
+  %0 = "unmoved_op"() : () -> (f32)
+  %1 = "before"() : () -> (f32)
+  cf.br ^bb0(%0 : f32) 
+ ^bb0(%arg0 : f32) :
+  %2 = "moved_op"(%arg0) : (f32) -> (f32)
+  return %1, %2 : f32, f32
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["moved_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
+    // expected-remark at +1{{unsupported case of moving definition of value before an insertion point in a different basic block}}
+    transform.test.move_value_defns %v1 before %op1
+        : (!transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/lib/Transforms/TestTransformsOps.cpp b/mlir/test/lib/Transforms/TestTransformsOps.cpp
index aaa566d9938a3..3d95af59f6da3 100644
--- a/mlir/test/lib/Transforms/TestTransformsOps.cpp
+++ b/mlir/test/lib/Transforms/TestTransformsOps.cpp
@@ -39,6 +39,24 @@ transform::TestMoveOperandDeps::apply(TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+DiagnosedSilenceableFailure
+transform::TestMoveValueDefns::apply(TransformRewriter &rewriter,
+                                      TransformResults &transformResults,
+                                      TransformState &state) {
+  SmallVector<Value> values; // All payload values of all `values` handles.
+  for (Value tdValue : getValues())
+    llvm::append_range(values, state.getPayloadValues(tdValue));
+  auto insertionOps = state.getPayloadOps(getInsertionPoint());
+  if (!llvm::hasSingleElement(insertionOps))
+    return emitSilenceableError() << "expected exactly one insertion point";
+  if (failed(moveValueDefinitions(rewriter, values, *insertionOps.begin()))) {
+    auto listener = cast<ErrorCheckingTrackingListener>(rewriter.getListener());
+    (void)emitRemark(listener->getLatestMatchFailureMessage());
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+
 namespace {
 
 class TestTransformsDialectExtension
diff --git a/mlir/test/lib/Transforms/TestTransformsOps.td b/mlir/test/lib/Transforms/TestTransformsOps.td
index f514702cef5bc..495579b452dfc 100644
--- a/mlir/test/lib/Transforms/TestTransformsOps.td
+++ b/mlir/test/lib/Transforms/TestTransformsOps.td
@@ -38,4 +38,26 @@ def TestMoveOperandDeps :
   }];
 }
 
+def TestMoveValueDefns :
+    Op<Transform_Dialect, "test.move_value_defns",
+        [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+         DeclareOpInterfaceMethods<TransformOpInterface>,
+         ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Moves the definitions of `values` before the `insertion_point` operation.
+  }];
+
+  let arguments =
+    (ins Variadic<TransformValueHandleTypeInterface>:$values,
+         TransformHandleTypeInterface:$insertion_point);
+
+  let results = (outs);
+
+  let assemblyFormat = [{
+    $values `before` $insertion_point attr-dict
+    `:` `(` type($values) `)` `` `,` type($insertion_point)
+  }];
+}
+
+
 #endif // TEST_TRANSFORM_OPS

``````````

</details>


https://github.com/llvm/llvm-project/pull/130874


More information about the Mlir-commits mailing list