[Mlir-commits] [mlir] [mlir][Hoisting] Hoisting vector.extract/vector.broadcast pairs (PR #86108)
Steven Varoumas
llvmlistbot at llvm.org
Fri Apr 19 08:00:30 PDT 2024
https://github.com/stevenvar updated https://github.com/llvm/llvm-project/pull/86108
>From 81abb241a8ba5b62bba09817046f377e982ab194 Mon Sep 17 00:00:00 2001
From: Steven Varoumas <steven.varoumas1 at huawei.com>
Date: Thu, 21 Mar 2024 18:19:35 +0800
Subject: [PATCH 1/3] [mlir][Hoisting] Hoisting vector.extract/vector.broadcast
pairs
---
.../Linalg/TransformOps/LinalgTransformOps.td | 36 ++++++
.../mlir/Dialect/Linalg/Transforms/Hoisting.h | 2 +
.../TransformOps/LinalgTransformOps.cpp | 14 +++
.../Dialect/Linalg/Transforms/Hoisting.cpp | 114 ++++++++++++++++++
mlir/test/Dialect/Linalg/hoisting.mlir | 109 +++++++++++++++++
5 files changed, 275 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 8edaa7db6cef3b..313f2aca8f0c9f 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2206,6 +2206,42 @@ def HoistRedundantVectorTransfersOp :
}];
}
+//===----------------------------------------------------------------------===//
+// HoistRedundantVectorBroadcastsOp
+//===----------------------------------------------------------------------===//
+
+def HoistRedundantVectorBroadcastsOp :
+ Op<Transform_Dialect, "structured.hoist_redundant_vector_broadcasts",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformEachOpTrait, TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ Hoist vector.extract / vector.broadcast pairs out of immediately
+ enclosing scf::ForOp iteratively.
+
+ #### Return modes:
+
+ The operation always succeeds and returns a handle to the transformed
+ function op.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs TransformHandleTypeInterface:$transformed);
+
+ let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
+
+ let builders = [
+ OpBuilder<(ins "Value":$target)>,
+ ];
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::func::FuncOp target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
//===----------------------------------------------------------------------===//
// ConvertConv2DToImg2ColOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
index 186e83a57580f3..11886d4876a97f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
@@ -43,6 +43,8 @@ namespace linalg {
/// when used on distributed loops with memref semantics!
void hoistRedundantVectorTransfers(Operation *root);
+void hoistRedundantVectorBroadcasts(Operation *root);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 7e7cf1d0244613..7166bc19745d05 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3306,6 +3306,20 @@ transform::HoistRedundantVectorTransfersOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// HoistRedundantVectorBroadcastsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::HoistRedundantVectorBroadcastsOp::applyToOne(
+ transform::TransformRewriter &rewriter, func::FuncOp target,
+ transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+ linalg::hoistRedundantVectorBroadcasts(target);
+ results.push_back(target);
+ return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// ConvertConv2DToImg2ColOp.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 34c9b2c282965c..98521cd745216c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -43,6 +43,120 @@ using llvm::dbgs;
using namespace mlir;
using namespace mlir::linalg;
+scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop,
+ Value newInitOperand, int index,
+ Value newYieldValue) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(loop.getOperation());
+ auto inits = llvm::to_vector(loop.getInits());
+
+ // Replace the init value with the new operand
+ inits[index] = newInitOperand;
+
+ scf::ForOp newLoop = rewriter.create<scf::ForOp>(
+ loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
+ inits, [](OpBuilder &, Location, Value, ValueRange) {});
+
+ // Generate the new yield with the replaced operand
+ auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
+ yieldOp->getOperand(index).replaceAllUsesWith(newYieldValue);
+
+ // Move the loop body to the new op.
+ rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(),
+ newLoop.getBody()->getArguments().take_front(
+ loop.getBody()->getNumArguments()));
+
+ // Replace the old loop.
+ rewriter.replaceOp(loop.getOperation(),
+ newLoop->getResults().take_front(loop.getNumResults()));
+ return newLoop;
+}
+
+// Hoist out a pair of corresponding vector.extract+vector.broadcast
+// operations. This function transforms a loop like this:
+// %loop = scf.for _ = _ to _ step _ iter_args(%iterarg = %v) -> (t1) {
+// %e = vector.extract %iterarg : t1 to t2
+// %u = // do something with %e : t2
+// %b = vector.broadcast %u : t2 to t1
+// scf.yield %b : t1
+// }
+// into the following:
+// %e = vector.extract %v: t1 to t2
+// %loop' = scf.for _ = _ to _ step _ iter_args(%iterarg = %e) -> (t2) {
+// %u' = // do something with %iterarg : t2
+// scf.yield %u' : t2
+// }
+// %loop = vector.broadcast %loop' : t2 to t1
+void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
+ bool changed = true;
+ while (changed) {
+ changed = false;
+ // First move loop invariant ops outside of their loop. This needs to be
+ // done before as we cannot move ops without interrupting the function walk.
+ root->walk(
+ [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
+
+ root->walk([&](vector::ExtractOp extractOp) {
+ LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
+ << *extractOp.getOperation() << "\n");
+
+ auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp());
+ if (!loop)
+ return WalkResult::advance();
+
+ // Check that the vector to extract from is an iter_arg
+ auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
+ if (!blockArg)
+ return WalkResult::advance();
+
+ // If the iter_arg does not have only one use, it won't be possible to
+ // hoist the extractOp out.
+ if (!blockArg.hasOneUse())
+ return WalkResult::advance();
+
+ auto initArg = loop.getTiedLoopInit(blockArg)->get();
+ auto index = blockArg.getArgNumber() - loop.getNumInductionVars();
+
+ // Check that the loop yields a broadcast
+ auto lastOp = loop.getBody()->getTerminator();
+ auto yieldOp = dyn_cast<scf::YieldOp>(lastOp);
+ if (!yieldOp)
+ return WalkResult::advance();
+
+ auto broadcast = dyn_cast<vector::BroadcastOp>(
+ yieldOp->getOperand(index).getDefiningOp());
+
+ LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");
+
+ Type broadcastInputType = broadcast.getSourceType();
+ if (broadcastInputType != extractOp.getType())
+ return WalkResult::advance();
+
+ // The position of the extract must be defined outside of the loop if
+ // it is dynamic
+ for (auto operand : extractOp.getDynamicPosition())
+ if (!loop.isDefinedOutsideOfLoop(operand))
+ return WalkResult::advance();
+
+ extractOp.getVectorMutable().assign(initArg);
+ loop.moveOutOfLoop(extractOp);
+ broadcast->moveAfter(loop);
+
+ IRRewriter rewriter(extractOp.getContext());
+ auto newLoop = replaceWithDifferentYield(
+ rewriter, loop, extractOp.getResult(), index, broadcast.getSource());
+
+ LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
+
+ newLoop.getResult(index).replaceAllUsesWith(broadcast);
+ broadcast.getSourceMutable().assign(newLoop.getResult(index));
+
+ changed = true;
+ return WalkResult::interrupt();
+ });
+ }
+}
+
static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
LoopLikeOpInterface loop) {
Value source = transferRead.getSource();
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 550ffbc7bab678..5a640348be90cf 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -565,3 +565,112 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// Test hoisting of vector.extract/vector.broadcast pairs
+
+// CHECK-LABEL: func.func @hoist_vector_broadcasts
+// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>) -> vector<3x4xf32> {
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][0] : vector<4xf32> from vector<3x4xf32>
+// CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} {
+// CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// CHECK-NEXT: scf.yield %[[USE]] : vector<4xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
+// CHECK-NEXT: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @hoist_vector_broadcasts(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>) -> vector<3x4xf32> {
+ %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
+ %extract = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
+ %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
+ %broadcast = vector.broadcast %use : vector<4xf32> to vector<3x4xf32>
+ scf.yield %broadcast : vector<3x4xf32>
+ }
+ return %bcast_vec : vector<3x4xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.hoist_redundant_vector_broadcasts %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Test hoisting of vector.extract/vector.broadcast pairs with dynamic position
+
+// CHECK-LABEL: func.func @hoist_vector_broadcasts_dynamic
+// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> {
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][%[[POS]]] : vector<4xf32> from vector<3x4xf32>
+// CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} {
+// CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// CHECK-NEXT: scf.yield %[[USE]] : vector<4xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
+// CHECK-NEXT: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @hoist_vector_broadcasts_dynamic(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>, %pos: index) -> vector<3x4xf32> {
+ %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
+ %extract = vector.extract %iarg[%pos] : vector<4xf32> from vector<3x4xf32>
+ %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
+ %broadcast = vector.broadcast %use : vector<4xf32> to vector<3x4xf32>
+ scf.yield %broadcast : vector<3x4xf32>
+ }
+ return %bcast_vec : vector<3x4xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.hoist_redundant_vector_broadcasts %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Test hoisting of vector.extract/vector.broadcast pairs with multiple iter_args
+
+// CHECK-LABEL: func.func @hoist_vector_broadcasts_multiple
+// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC1:.+]]: vector<3x4xf32>,
+// CHECK-SAME: %[[VEC2:.+]]: vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
+// CHECK-DAG: %[[EXTRACT1:.+]] = vector.extract %[[VEC1]][0] : vector<4xf32> from vector<3x4xf32>
+// CHECK-DAG: %[[EXTRACT2:.+]] = vector.extract %[[VEC2]][1] : vector<5xf32> from vector<3x5xf32>
+// CHECK-NEXT: %[[LOOP:.+]]:2 = scf.for {{.*}} {
+// CHECK-DAG: %[[USE1:.+]] = "some_use1"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// CHECK-DAG: %[[USE2:.+]] = "some_use2"({{.*}}) : (vector<5xf32>) -> vector<5xf32>
+// CHECK-NEXT: scf.yield %[[USE1]], %[[USE2]] : vector<4xf32>, vector<5xf32>
+// CHECK-NEXT: }
+// CHECK-DAG: %[[BCAST1:.+]] = vector.broadcast %[[LOOP]]#0 : vector<4xf32> to vector<3x4xf32>
+// CHECK-DAG: %[[BCAST2:.+]] = vector.broadcast %[[LOOP]]#1 : vector<5xf32> to vector<3x5xf32>
+// CHECK-NEXT: return %[[BCAST1]], %[[BCAST2]] : vector<3x4xf32>, vector<3x5xf32>
+
+func.func @hoist_vector_broadcasts_multiple(%lb : index, %ub : index, %step : index, %vec1 : vector<3x4xf32>, %vec2 : vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
+ %bcast_vec:2 = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec1, %iarg2 = %vec2) -> (vector<3x4xf32>, vector<3x5xf32>) {
+ %extract1 = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
+ %extract2 = vector.extract %iarg2[1] : vector<5xf32> from vector<3x5xf32>
+ %use1 = "some_use1"(%extract1) : (vector<4xf32>) -> vector<4xf32>
+ %use2 = "some_use2"(%extract2) : (vector<5xf32>) -> vector<5xf32>
+ %broadcast1 = vector.broadcast %use1 : vector<4xf32> to vector<3x4xf32>
+ %broadcast2 = vector.broadcast %use2 : vector<5xf32> to vector<3x5xf32>
+ scf.yield %broadcast1, %broadcast2 : vector<3x4xf32>,vector<3x5xf32>
+ }
+ return %bcast_vec#0, %bcast_vec#1 : vector<3x4xf32>, vector<3x5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.hoist_redundant_vector_broadcasts %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
\ No newline at end of file
>From 854e4d1ffad22d5a2871415eed321c71e343ed1c Mon Sep 17 00:00:00 2001
From: Steven Varoumas <steven.varoumas1 at huawei.com>
Date: Mon, 15 Apr 2024 22:27:56 +0800
Subject: [PATCH 2/3] review comments
---
.../Linalg/TransformOps/LinalgTransformOps.td | 2 +-
.../TransformOps/LinalgTransformOps.cpp | 2 +-
.../Dialect/Linalg/Transforms/Hoisting.cpp | 35 ++++++++++---------
3 files changed, 20 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 313f2aca8f0c9f..157dc671f72008 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2236,7 +2236,7 @@ def HoistRedundantVectorBroadcastsOp :
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
- ::mlir::func::FuncOp target,
+ ::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 7166bc19745d05..dc2c86a664f440 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3312,7 +3312,7 @@ transform::HoistRedundantVectorTransfersOp::applyToOne(
DiagnosedSilenceableFailure
transform::HoistRedundantVectorBroadcastsOp::applyToOne(
- transform::TransformRewriter &rewriter, func::FuncOp target,
+ transform::TransformRewriter &rewriter, mlir::Operation *target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
linalg::hoistRedundantVectorBroadcasts(target);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 98521cd745216c..9a34b7b11c5702 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -59,7 +59,7 @@ scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop,
// Generate the new yield with the replaced operand
auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
- yieldOp->getOperand(index).replaceAllUsesWith(newYieldValue);
+ rewriter.replaceAllUsesWith(yieldOp->getOperand(index), newYieldValue);
// Move the loop body to the new op.
rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(),
@@ -74,19 +74,19 @@ scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop,
// Hoist out a pair of corresponding vector.extract+vector.broadcast
// operations. This function transforms a loop like this:
-// %loop = scf.for _ = _ to _ step _ iter_args(%iterarg = %v) -> (t1) {
-// %e = vector.extract %iterarg : t1 to t2
-// %u = // do something with %e : t2
+// %res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) {
+// %e = vector.extract %iarg : t1 to t2
+// %u = "some_use"(%e) : (t2) -> t2
// %b = vector.broadcast %u : t2 to t1
// scf.yield %b : t1
// }
// into the following:
// %e = vector.extract %v: t1 to t2
-// %loop' = scf.for _ = _ to _ step _ iter_args(%iterarg = %e) -> (t2) {
-// %u' = // do something with %iterarg : t2
+// %res' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) {
+// %u' = "some_use"(%iarg) : (t2) -> t2
// scf.yield %u' : t2
// }
-// %loop = vector.broadcast %loop' : t2 to t1
+// %res = vector.broadcast %res' : t2 to t1
void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
bool changed = true;
while (changed) {
@@ -118,14 +118,12 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
auto index = blockArg.getArgNumber() - loop.getNumInductionVars();
// Check that the loop yields a broadcast
- auto lastOp = loop.getBody()->getTerminator();
- auto yieldOp = dyn_cast<scf::YieldOp>(lastOp);
- if (!yieldOp)
+ auto yieldedVal =
+ loop.getTiedLoopYieldedValue(blockArg)->get().getDefiningOp();
+ auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal);
+ if (!broadcast)
return WalkResult::advance();
- auto broadcast = dyn_cast<vector::BroadcastOp>(
- yieldOp->getOperand(index).getDefiningOp());
-
LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");
Type broadcastInputType = broadcast.getSourceType();
@@ -138,18 +136,21 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
if (!loop.isDefinedOutsideOfLoop(operand))
return WalkResult::advance();
+ IRRewriter rewriter(extractOp.getContext());
+
extractOp.getVectorMutable().assign(initArg);
loop.moveOutOfLoop(extractOp);
- broadcast->moveAfter(loop);
+ rewriter.moveOpAfter(broadcast, loop);
- IRRewriter rewriter(extractOp.getContext());
auto newLoop = replaceWithDifferentYield(
rewriter, loop, extractOp.getResult(), index, broadcast.getSource());
LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
- newLoop.getResult(index).replaceAllUsesWith(broadcast);
- broadcast.getSourceMutable().assign(newLoop.getResult(index));
+ rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast);
+ rewriter.modifyOpInPlace(broadcast, [&] {
+ broadcast.getSourceMutable().assign(newLoop.getResult(index));
+ });
changed = true;
return WalkResult::interrupt();
>From bf4eb782366e61ecf4316f0d237508d07e86d708 Mon Sep 17 00:00:00 2001
From: Steven Varoumas <steven.varoumas1 at huawei.com>
Date: Fri, 19 Apr 2024 18:18:45 +0800
Subject: [PATCH 3/3] review comments
---
.../mlir/Dialect/Linalg/Transforms/Hoisting.h | 11 +++-
.../TransformOps/LinalgTransformOps.cpp | 3 +-
.../Dialect/Linalg/Transforms/Hoisting.cpp | 61 +++++++++++--------
mlir/test/Dialect/Linalg/hoisting.mlir | 2 +-
4 files changed, 49 insertions(+), 28 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
index 11886d4876a97f..236c2ce7d48e39 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
@@ -43,7 +43,16 @@ namespace linalg {
/// when used on distributed loops with memref semantics!
void hoistRedundantVectorTransfers(Operation *root);
-void hoistRedundantVectorBroadcasts(Operation *root);
+/// Hoist vector.extract/vector.broadcast pairs out of immediately enclosing
+/// scf::ForOp iteratively, if the following conditions are met:
+/// 1. The vector.extract operation is applied on an iter_argument, and no
+/// other operation uses this argument in the body of the loop.
+/// 2. The position of the vector.extract is either a static value, or defined
+/// outside of the loop.
+/// 3. The vector.broadcast operation is yielded by the loop.
+/// To improve hoisting opportunities, call the `moveLoopInvariantCode` helper
+/// function on the candidate loop above which to hoist.
+void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root);
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index dc2c86a664f440..8f97c83e29cefc 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3315,7 +3315,8 @@ transform::HoistRedundantVectorBroadcastsOp::applyToOne(
transform::TransformRewriter &rewriter, mlir::Operation *target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
- linalg::hoistRedundantVectorBroadcasts(target);
+ rewriter.setInsertionPoint(target);
+ linalg::hoistRedundantVectorBroadcasts(rewriter, target);
results.push_back(target);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 9a34b7b11c5702..94f6b602987555 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -43,32 +43,39 @@ using llvm::dbgs;
using namespace mlir;
using namespace mlir::linalg;
-scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop,
- Value newInitOperand, int index,
- Value newYieldValue) {
+/// Replace `loop` with a new loop that has a different init operand at
+/// position `index`. The body of this loop is moved over to the new loop.
+///
+/// `newInitOperand` specifies the replacement "init" operand.
+/// `newYieldValue` is the replacement yield value of the loop at position
+/// `index`.
+static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
+ scf::ForOp loop,
+ Value newInitOperand,
+ unsigned index,
+ Value newYieldValue) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(loop.getOperation());
auto inits = llvm::to_vector(loop.getInits());
- // Replace the init value with the new operand
+ // Replace the init value with the new operand.
+ assert(index < inits.size());
inits[index] = newInitOperand;
scf::ForOp newLoop = rewriter.create<scf::ForOp>(
loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
inits, [](OpBuilder &, Location, Value, ValueRange) {});
- // Generate the new yield with the replaced operand
+ // Generate the new yield with the replaced operand.
auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
- rewriter.replaceAllUsesWith(yieldOp->getOperand(index), newYieldValue);
+ yieldOp.setOperand(index, newYieldValue);
// Move the loop body to the new op.
rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(),
- newLoop.getBody()->getArguments().take_front(
- loop.getBody()->getNumArguments()));
+ newLoop.getBody()->getArguments());
// Replace the old loop.
- rewriter.replaceOp(loop.getOperation(),
- newLoop->getResults().take_front(loop.getNumResults()));
+ rewriter.replaceOp(loop.getOperation(), newLoop->getResults());
return newLoop;
}
@@ -87,7 +94,8 @@ scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop,
// scf.yield %u' : t2
// }
// %res = vector.broadcast %res' : t2 to t1
-void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
+void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
+ Operation *root) {
bool changed = true;
while (changed) {
changed = false;
@@ -104,24 +112,28 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
if (!loop)
return WalkResult::advance();
- // Check that the vector to extract from is an iter_arg
+ // Check that the vector to extract from is a BlockArgument.
auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
if (!blockArg)
return WalkResult::advance();
+ // Check that the blockArg is an iter_arg of the loop.
+ OpOperand *initArg = loop.getTiedLoopInit(blockArg);
+ if (!initArg)
+ return WalkResult::advance();
+
// If the iter_arg does not have only one use, it won't be possible to
// hoist the extractOp out.
if (!blockArg.hasOneUse())
return WalkResult::advance();
- auto initArg = loop.getTiedLoopInit(blockArg)->get();
- auto index = blockArg.getArgNumber() - loop.getNumInductionVars();
+ unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars();
- // Check that the loop yields a broadcast
- auto yieldedVal =
+ // Check that the loop yields a broadcast that has just one use.
+ Operation *yieldedVal =
loop.getTiedLoopYieldedValue(blockArg)->get().getDefiningOp();
auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal);
- if (!broadcast)
+ if (!broadcast || !broadcast.getResult().hasOneUse())
return WalkResult::advance();
LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");
@@ -131,26 +143,25 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
return WalkResult::advance();
// The position of the extract must be defined outside of the loop if
- // it is dynamic
+ // it is dynamic.
for (auto operand : extractOp.getDynamicPosition())
if (!loop.isDefinedOutsideOfLoop(operand))
return WalkResult::advance();
- IRRewriter rewriter(extractOp.getContext());
-
- extractOp.getVectorMutable().assign(initArg);
+ rewriter.modifyOpInPlace(broadcast, [&] {
+ extractOp.getVectorMutable().assign(initArg->get());
+ });
loop.moveOutOfLoop(extractOp);
rewriter.moveOpAfter(broadcast, loop);
- auto newLoop = replaceWithDifferentYield(
+ scf::ForOp newLoop = replaceWithDifferentYield(
rewriter, loop, extractOp.getResult(), index, broadcast.getSource());
LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast);
- rewriter.modifyOpInPlace(broadcast, [&] {
- broadcast.getSourceMutable().assign(newLoop.getResult(index));
- });
+ rewriter.modifyOpInPlace(
+ broadcast, [&] { broadcast.setOperand(newLoop.getResult(index)); });
changed = true;
return WalkResult::interrupt();
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 5a640348be90cf..241b8a486c012e 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -673,4 +673,4 @@ module attributes {transform.with_named_sequence} {
: (!transform.any_op) -> !transform.any_op
transform.yield
}
-}
\ No newline at end of file
+}
More information about the Mlir-commits
mailing list