[Mlir-commits] [mlir] [mlir][Transforms] Add a utility method to move value definitions. (PR #130874)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 11 21:55:52 PDT 2025
https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/130874
>From 00624428051e99ff9206942bbdab444905cf2eca Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Tue, 11 Mar 2025 18:21:25 -0700
Subject: [PATCH 1/2] [mlir][Transforms] Add a utility method to move value
definitions.
https://github.com/llvm/llvm-project/commit/205c5325b3c771d94feb0ec07e8ad89d27c2b29e
added a transform utility that moved all SSA dependences of an
operation before an insertion point. Similar to that, this PR adds a
transform utility function, `moveValueDefinitions` to move the slice
of operations that define all values in a `ValueRange` before the
insertion point. While very similar to `moveOperationDependencies`,
this method differs in a few ways:
1. Since the backward slice starts from a value rather than an operation,
the slice needs to be computed inclusively so that it contains the
value's defining operation.
2. The combined backward slice needs to be sorted topologically before
its operations are moved, to avoid SSA use-def violations while moving
individual ops.
The PR also adds a new transform op to test this new utility function.
Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
mlir/include/mlir/Transforms/RegionUtils.h | 11 +
mlir/lib/Transforms/Utils/RegionUtils.cpp | 71 +++++-
mlir/test/Transforms/move-operation-deps.mlir | 226 ++++++++++++++++++
.../test/lib/Transforms/TestTransformsOps.cpp | 18 ++
mlir/test/lib/Transforms/TestTransformsOps.td | 22 ++
5 files changed, 347 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
index e6b928d8ebecc..2ed96afbace81 100644
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -11,6 +11,7 @@
#include "mlir/IR/Region.h"
#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/SetVector.h"
@@ -80,6 +81,16 @@ LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
Operation *insertionPoint);
+/// Move the definitions of `values`, together with the backward slice of
+/// operations they depend on, before `insertionPoint`. Values that already
+/// properly dominate `insertionPoint` are left in place. Current support is
+/// only for movement of definitions within the same basic block: block
+/// arguments, definitions in a different block than `insertionPoint`, and
+/// slices that contain `insertionPoint` itself are rejected. Note that this
+/// is an all-or-nothing approach. Either definitions of all values are moved
+/// before insertion point, or none of them are; on failure the reason is
+/// reported via `rewriter.notifyMatchFailure`.
+LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
+ Operation *insertionPoint,
+ DominanceInfo &dominance);
+/// Same as above, but constructs the `DominanceInfo` from `insertionPoint`.
+LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
+ Operation *insertionPoint);
+
/// Run a set of structural simplifications over the given regions. This
/// includes transformations like unreachable block elimination, dead argument
/// elimination, as well as some other DCE. This function returns success if any
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index da0d486f0fdcb..6987a13b309d7 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -1070,7 +1070,7 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
// in different basic blocks.
if (op->getBlock() != insertionPoint->getBlock()) {
return rewriter.notifyMatchFailure(
- op, "unsupported caes where operation and insertion point are not in "
+ op, "unsupported case where operation and insertion point are not in "
"the same basic block");
}
// If `insertionPoint` does not dominate `op`, do nothing
@@ -1115,3 +1115,72 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
DominanceInfo dominance(op);
return moveOperationDependencies(rewriter, op, insertionPoint, dominance);
}
+
+/// Moves the definitions of `values` (and their backward slices) before
+/// `insertionPoint`. All legality checks run before any operation is moved,
+/// so on failure the IR is left untouched.
+LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
+                                         ValueRange values,
+                                         Operation *insertionPoint,
+                                         DominanceInfo &dominance) {
+  // Remove the values that already dominate the insertion point, and reject
+  // the unsupported cases before touching the IR.
+  SmallVector<Value> prunedValues;
+  for (Value value : values) {
+    if (dominance.properlyDominates(value, insertionPoint))
+      continue;
+    // Block arguments are not supported.
+    if (isa<BlockArgument>(value)) {
+      return rewriter.notifyMatchFailure(
+          insertionPoint,
+          "unsupported case of moving block argument before insertion point");
+    }
+    // Check for currently unsupported case if the insertion point is in a
+    // different block.
+    if (value.getDefiningOp()->getBlock() != insertionPoint->getBlock()) {
+      return rewriter.notifyMatchFailure(
+          insertionPoint,
+          "unsupported case of moving definition of value before an insertion "
+          "point in a different basic block");
+    }
+    prunedValues.push_back(value);
+  }
+
+  // Compute the backward slice rooted at the defining op of each value. The
+  // slice must be inclusive since the defining op itself has to move too.
+  // Prune the slice to only include operations that do not already properly
+  // dominate the `insertionPoint`.
+  BackwardSliceOptions options;
+  options.inclusive = true;
+  options.omitUsesFromAbove = false;
+  // Since current support is only to move within the same basic block,
+  // the slices don't need to look past block arguments.
+  options.omitBlockArguments = true;
+  options.filter = [&](Operation *sliceBoundaryOp) {
+    return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
+  };
+  llvm::SetVector<Operation *> slice;
+  for (Value value : prunedValues)
+    getBackwardSlice(value, &slice, options);
+
+  // If the slice contains `insertionPoint` cannot move the dependencies.
+  if (slice.contains(insertionPoint)) {
+    return rewriter.notifyMatchFailure(
+        insertionPoint,
+        "cannot move dependencies before operation in backward slice of op");
+  }
+
+  // Sort operations topologically before moving. `mlir::topologicalSort`
+  // returns the sorted set instead of sorting its argument in place, so the
+  // result has to be assigned back.
+  slice = mlir::topologicalSort(slice);
+
+  for (Operation *op : slice) {
+    rewriter.moveOpBefore(op, insertionPoint);
+  }
+  return success();
+}
+
+LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
+                                         ValueRange values,
+                                         Operation *insertionPoint) {
+  // Convenience overload: build dominance information from the insertion
+  // point and defer to the overload that takes it explicitly.
+  DominanceInfo domInfo(insertionPoint);
+  return mlir::moveValueDefinitions(rewriter, values, insertionPoint, domInfo);
+}
diff --git a/mlir/test/Transforms/move-operation-deps.mlir b/mlir/test/Transforms/move-operation-deps.mlir
index 37637152938f6..aa7b5dc2a240a 100644
--- a/mlir/test/Transforms/move-operation-deps.mlir
+++ b/mlir/test/Transforms/move-operation-deps.mlir
@@ -234,3 +234,229 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// Check that the definitions of the given values are moved before the
+// insertion operation.
+func.func @simple_move_values() -> f32 {
+  %0 = "before"() : () -> (f32)
+  %1 = "moved_op_1"() : () -> (f32)
+  %2 = "moved_op_2"() : () -> (f32)
+  %3 = "foo"(%1, %2) : (f32, f32) -> (f32)
+  return %3 : f32
+}
+// CHECK-LABEL: func @simple_move_values()
+//       CHECK:   %[[MOVED1:.+]] = "moved_op_1"
+//       CHECK:   %[[MOVED2:.+]] = "moved_op_2"
+//       CHECK:   %[[BEFORE:.+]] = "before"
+//       CHECK:   %[[FOO:.+]] = "foo"(%[[MOVED1]], %[[MOVED2]])
+//       CHECK:   return %[[FOO]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %moved_1 = transform.structured.match ops{["moved_op_1"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %moved_2 = transform.structured.match ops{["moved_op_2"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %before = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %value_1 = transform.get_result %moved_1[0]
+        : (!transform.any_op) -> !transform.any_value
+    %value_2 = transform.get_result %moved_2[0]
+        : (!transform.any_op) -> !transform.any_value
+    transform.test.move_value_defns %value_1, %value_2 before %before
+        : (!transform.any_value, !transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Check that the slice also includes the values implicitly captured by the
+// regions of the moved operations.
+func.func @move_region_dependencies_values() -> f32 {
+  %0 = "before"() : () -> (f32)
+  %1 = "moved_op_1"() : () -> (f32)
+  %2 = "moved_op_2"() ({
+    %3 = "inner_op"(%1) : (f32) -> (f32)
+    "yield"(%3) : (f32) -> ()
+  }) : () -> (f32)
+  return %2 : f32
+}
+// CHECK-LABEL: func @move_region_dependencies_values()
+//       CHECK:   %[[MOVED1:.+]] = "moved_op_1"
+//       CHECK:   %[[MOVED2:.+]] = "moved_op_2"
+//       CHECK:   %[[BEFORE:.+]] = "before"
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %region_op = transform.structured.match ops{["moved_op_2"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %before = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %region_result = transform.get_result %region_op[0]
+        : (!transform.any_op) -> !transform.any_value
+    transform.test.move_value_defns %region_result before %before
+        : (!transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Check that the combined slice is moved in topological order, so that every
+// moved operation still comes after the definitions of its operands.
+func.func @move_values_in_topological_sort_order() -> f32 {
+  %0 = "before"() : () -> (f32)
+  %1 = "moved_op_1"() : () -> (f32)
+  %2 = "moved_op_2"() : () -> (f32)
+  %3 = "moved_op_3"(%1) : (f32) -> (f32)
+  %4 = "moved_op_4"(%1, %3) : (f32, f32) -> (f32)
+  %5 = "moved_op_5"(%2) : (f32) -> (f32)
+  %6 = "foo"(%4, %5) : (f32, f32) -> (f32)
+  return %6 : f32
+}
+// CHECK-LABEL: func @move_values_in_topological_sort_order()
+//       CHECK:   %[[MOVED_1:.+]] = "moved_op_1"
+//   CHECK-DAG:   %[[MOVED_3:.+]] = "moved_op_3"(%[[MOVED_1]])
+//   CHECK-DAG:   %[[MOVED_4:.+]] = "moved_op_4"(%[[MOVED_1]], %[[MOVED_3]])
+//   CHECK-DAG:   %[[MOVED_2:.+]] = "moved_op_2"
+//   CHECK-DAG:   %[[MOVED_5:.+]] = "moved_op_5"(%[[MOVED_2]])
+//       CHECK:   %[[BEFORE:.+]] = "before"
+//       CHECK:   %[[FOO:.+]] = "foo"(%[[MOVED_4]], %[[MOVED_5]])
+//       CHECK:   return %[[FOO]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op_4 = transform.structured.match ops{["moved_op_4"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op_5 = transform.structured.match ops{["moved_op_5"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %before = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %value_4 = transform.get_result %op_4[0]
+        : (!transform.any_op) -> !transform.any_value
+    %value_5 = transform.get_result %op_5[0]
+        : (!transform.any_op) -> !transform.any_value
+    transform.test.move_value_defns %value_4, %value_5 before %before
+        : (!transform.any_value, !transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Only the definitions that do not already dominate the insertion point are
+// moved; `unmoved_op` stays where it is.
+func.func @move_only_required_defns() -> (f32, f32, f32, f32) {
+  %0 = "unmoved_op"() : () -> (f32)
+  %1 = "dummy_op"() : () -> (f32)
+  %2 = "before"() : () -> (f32)
+  %3 = "moved_op"() : () -> (f32)
+  return %0, %1, %2, %3 : f32, f32, f32, f32
+}
+// CHECK-LABEL: func @move_only_required_defns()
+//       CHECK:   %[[UNMOVED:.+]] = "unmoved_op"
+//       CHECK:   %[[DUMMY:.+]] = "dummy_op"
+//       CHECK:   %[[MOVED:.+]] = "moved_op"
+//       CHECK:   %[[BEFORE:.+]] = "before"
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %unmoved = transform.structured.match ops{["unmoved_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %dummy = transform.structured.match ops{["dummy_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %before = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %moved = transform.structured.match ops{["moved_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %unmoved_value = transform.get_result %unmoved[0]
+        : (!transform.any_op) -> !transform.any_value
+    %moved_value = transform.get_result %moved[0]
+        : (!transform.any_op) -> !transform.any_value
+    transform.test.move_value_defns %unmoved_value, %moved_value before %before
+        : (!transform.any_value, !transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Do not move any definition if the insertion point is itself part of the
+// backward slice of the values being moved.
+func.func @no_move_insertion_point_in_slice() -> f32 {
+  %0 = "before"() : () -> (f32)
+  %1 = "moved_op"(%0) : (f32) -> (f32)
+  return %1 : f32
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %before = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %moved = transform.structured.match ops{["moved_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %value = transform.get_result %moved[0]
+        : (!transform.any_op) -> !transform.any_value
+    // expected-remark@+1{{cannot move dependencies before operation in backward slice of op}}
+    transform.test.move_value_defns %value before %before
+        : (!transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Check that a block argument operand of the moved definition is not part of
+// the slice: only `moved_op` is moved, the block argument stays in place.
+func.func @move_value_with_block_argument_operand() -> (f32, f32) {
+  %0 = "unmoved_op"() : () -> (f32)
+  cf.br ^bb0(%0 : f32)
+^bb0(%arg0 : f32):
+  %1 = "before"() : () -> (f32)
+  %2 = "moved_op"(%arg0) : (f32) -> (f32)
+  return %1, %2 : f32, f32
+}
+// CHECK-LABEL: func @move_value_with_block_argument_operand()
+//       CHECK:   %[[MOVED:.+]] = "moved_op"
+//       CHECK:   %[[BEFORE:.+]] = "before"
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %before = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %moved = transform.structured.match ops{["moved_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %value = transform.get_result %moved[0]
+        : (!transform.any_op) -> !transform.any_value
+    transform.test.move_value_defns %value before %before
+        : (!transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Do not move across basic blocks: the definition of `moved_op` lives in a
+// different block than the insertion point `before`, which is currently
+// unsupported, so a remark is emitted and nothing is moved.
+func.func @no_move_across_basic_blocks() -> (f32, f32) {
+ %0 = "unmoved_op"() : () -> (f32)
+ %1 = "before"() : () -> (f32)
+ cf.br ^bb0(%0 : f32)
+ ^bb0(%arg0 : f32) :
+ %2 = "moved_op"(%arg0) : (f32) -> (f32)
+ return %1, %2 : f32, f32
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+ // Insertion point in the entry block, value defined in ^bb0.
+ %op1 = transform.structured.match ops{["before"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %op2 = transform.structured.match ops{["moved_op"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
+ // expected-remark at +1{{unsupported case of moving definition of value before an insertion point in a different basic block}}
+ transform.test.move_value_defns %v1 before %op1
+ : (!transform.any_value), !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/lib/Transforms/TestTransformsOps.cpp b/mlir/test/lib/Transforms/TestTransformsOps.cpp
index aaa566d9938a3..3d95af59f6da3 100644
--- a/mlir/test/lib/Transforms/TestTransformsOps.cpp
+++ b/mlir/test/lib/Transforms/TestTransformsOps.cpp
@@ -39,6 +39,24 @@ transform::TestMoveOperandDeps::apply(TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}
+/// Moves the definitions of the first payload value of each value handle
+/// before the first payload op of the insertion-point handle. A failed move is
+/// reported as a remark carrying the latest match-failure message, so tests
+/// can check it with `expected-remark`.
+DiagnosedSilenceableFailure
+transform::TestMoveValueDefns::apply(TransformRewriter &rewriter,
+                                     TransformResults &transformResults,
+                                     TransformState &state) {
+  SmallVector<Value> values;
+  for (Value tdValue : getValues()) {
+    auto payloadValues = state.getPayloadValues(tdValue);
+    // Guard against dereferencing begin() of an empty range.
+    if (payloadValues.empty()) {
+      return emitSilenceableError()
+             << "expected every value handle to be non-empty";
+    }
+    values.push_back(*payloadValues.begin());
+  }
+  auto insertionPoints = state.getPayloadOps(getInsertionPoint());
+  if (insertionPoints.begin() == insertionPoints.end()) {
+    return emitSilenceableError()
+           << "expected the insertion point handle to be non-empty";
+  }
+  Operation *moveBefore = *insertionPoints.begin();
+  if (failed(moveValueDefinitions(rewriter, values, moveBefore))) {
+    auto listener = cast<ErrorCheckingTrackingListener>(rewriter.getListener());
+    std::string errorMsg = listener->getLatestMatchFailureMessage();
+    (void)emitRemark(errorMsg);
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+
namespace {
class TestTransformsDialectExtension
diff --git a/mlir/test/lib/Transforms/TestTransformsOps.td b/mlir/test/lib/Transforms/TestTransformsOps.td
index f514702cef5bc..495579b452dfc 100644
--- a/mlir/test/lib/Transforms/TestTransformsOps.td
+++ b/mlir/test/lib/Transforms/TestTransformsOps.td
@@ -38,4 +38,26 @@ def TestMoveOperandDeps :
}];
}
+def TestMoveValueDefns :
+  Op<Transform_Dialect, "test.move_value_defns",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     DeclareOpInterfaceMethods<TransformOpInterface>,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Moves the definitions of the payload values associated with `values`,
+    together with the backward slice of operations they depend on, before the
+    payload operation associated with `insertion_point`. If the definitions
+    cannot be moved, the reason is reported as a remark.
+  }];
+
+  let arguments =
+    (ins Variadic<TransformValueHandleTypeInterface>:$values,
+         TransformHandleTypeInterface:$insertion_point);
+
+  let results = (outs);
+
+  let assemblyFormat = [{
+    $values `before` $insertion_point attr-dict
+    `:` `(` type($values) `)` `` `,` type($insertion_point)
+  }];
+}
+
+
#endif // TEST_TRANSFORM_OPS
>From 1122f2d86936b253519c6ad2df659218a1d4a110 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Tue, 11 Mar 2025 21:55:29 -0700
Subject: [PATCH 2/2] Address comments.
Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
mlir/lib/Transforms/Utils/RegionUtils.cpp | 2 --
mlir/test/lib/Transforms/TestTransformsOps.cpp | 5 ++---
2 files changed, 2 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 6987a13b309d7..18e079d153161 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -1170,8 +1170,6 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
// Sort operations topologically before moving.
mlir::topologicalSort(slice);
- // We should move the slice in topological order, but `getBackwardSlice`
- // already does that. So no need to sort again.
for (Operation *op : slice) {
rewriter.moveOpBefore(op, insertionPoint);
}
diff --git a/mlir/test/lib/Transforms/TestTransformsOps.cpp b/mlir/test/lib/Transforms/TestTransformsOps.cpp
index 3d95af59f6da3..c05b32bed9b94 100644
--- a/mlir/test/lib/Transforms/TestTransformsOps.cpp
+++ b/mlir/test/lib/Transforms/TestTransformsOps.cpp
@@ -41,8 +41,8 @@ transform::TestMoveOperandDeps::apply(TransformRewriter &rewriter,
DiagnosedSilenceableFailure
transform::TestMoveValueDefns::apply(TransformRewriter &rewriter,
- TransformResults &TransformResults,
- TransformState &state) {
+ TransformResults &TransformResults,
+ TransformState &state) {
SmallVector<Value> values;
for (auto tdValue : getValues()) {
values.push_back(*state.getPayloadValues(tdValue).begin());
@@ -56,7 +56,6 @@ transform::TestMoveValueDefns::apply(TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}
-
namespace {
class TestTransformsDialectExtension
More information about the Mlir-commits
mailing list