[Mlir-commits] [mlir] [mlir][Hoisting] Hoisting vector.extract/vector.broadcast pairs (PR #86108)

Steven Varoumas llvmlistbot at llvm.org
Tue Apr 16 01:53:37 PDT 2024


https://github.com/stevenvar updated https://github.com/llvm/llvm-project/pull/86108

>From a24e055ed1dbc6dfe9227a862417b2ef923958d8 Mon Sep 17 00:00:00 2001
From: Steven Varoumas <steven.varoumas1 at huawei.com>
Date: Thu, 21 Mar 2024 18:19:35 +0800
Subject: [PATCH 1/2] [mlir][Hoisting] Hoisting vector.extract/vector.broadcast
 pairs

---
 .../Linalg/TransformOps/LinalgTransformOps.td |  36 ++++++
 .../mlir/Dialect/Linalg/Transforms/Hoisting.h |   2 +
 .../TransformOps/LinalgTransformOps.cpp       |  14 +++
 .../Dialect/Linalg/Transforms/Hoisting.cpp    | 114 ++++++++++++++++++
 mlir/test/Dialect/Linalg/hoisting.mlir        | 109 +++++++++++++++++
 5 files changed, 275 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 8edaa7db6cef3b..313f2aca8f0c9f 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2206,6 +2206,42 @@ def HoistRedundantVectorTransfersOp :
    }];
 }
 
+//===----------------------------------------------------------------------===//
+// HoistRedundantVectorBroadcastsOp
+//===----------------------------------------------------------------------===//
+
+def HoistRedundantVectorBroadcastsOp :
+  Op<Transform_Dialect, "structured.hoist_redundant_vector_broadcasts",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformEachOpTrait, TransformOpInterface,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Hoist vector.extract / vector.broadcast pairs out of the immediately
+    enclosing scf::ForOp, iteratively.
+
+    #### Return modes:
+
+    The operation always succeeds and returns a handle to the transformed
+    function op.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>,
+  ];
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+         ::mlir::transform::TransformRewriter &rewriter,
+         ::mlir::func::FuncOp target,
+         ::mlir::transform::ApplyToEachResultList &results,
+         ::mlir::transform::TransformState &state);
+   }];
+}
+
 //===----------------------------------------------------------------------===//
 // ConvertConv2DToImg2ColOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
index 186e83a57580f3..11886d4876a97f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
@@ -43,6 +43,8 @@ namespace linalg {
 /// when used on distributed loops with memref semantics!
 void hoistRedundantVectorTransfers(Operation *root);
 
+void hoistRedundantVectorBroadcasts(Operation *root);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 7e7cf1d0244613..7166bc19745d05 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3306,6 +3306,20 @@ transform::HoistRedundantVectorTransfersOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// HoistRedundantVectorBroadcastsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::HoistRedundantVectorBroadcastsOp::applyToOne(
+    transform::TransformRewriter &rewriter, func::FuncOp target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  linalg::hoistRedundantVectorBroadcasts(target);
+  results.push_back(target);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // ConvertConv2DToImg2ColOp.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 34c9b2c282965c..98521cd745216c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -43,6 +43,120 @@ using llvm::dbgs;
 using namespace mlir;
 using namespace mlir::linalg;
 
+scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop,
+                                     Value newInitOperand, int index,
+                                     Value newYieldValue) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(loop.getOperation());
+  auto inits = llvm::to_vector(loop.getInits());
+
+  // Replace the init value with the new operand
+  inits[index] = newInitOperand;
+
+  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
+      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
+      inits, [](OpBuilder &, Location, Value, ValueRange) {});
+
+  // Generate the new yield with the replaced operand
+  auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
+  yieldOp->getOperand(index).replaceAllUsesWith(newYieldValue);
+
+  // Move the loop body to the new op.
+  rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(),
+                       newLoop.getBody()->getArguments().take_front(
+                           loop.getBody()->getNumArguments()));
+
+  // Replace the old loop.
+  rewriter.replaceOp(loop.getOperation(),
+                     newLoop->getResults().take_front(loop.getNumResults()));
+  return newLoop;
+}
+
+// Hoist out a pair of corresponding vector.extract+vector.broadcast
+// operations. This function transforms a loop like this:
+//  %loop = scf.for _ = _ to _ step _ iter_args(%iterarg = %v) -> (t1) {
+//   %e = vector.extract %iterarg : t1 to t2
+//   %u = // do something with %e : t2
+//   %b = vector.broadcast %u : t2 to t1
+//   scf.yield %b : t1
+//  }
+// into the following:
+//  %e = vector.extract %v: t1 to t2
+//  %loop' = scf.for _ = _ to _ step _ iter_args(%iterarg = %e) -> (t2) {
+//   %u' = // do something with %iterarg : t2
+//   scf.yield %u' : t2
+//  }
+//  %loop = vector.broadcast %loop' : t2 to t1
+void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
+  bool changed = true;
+  while (changed) {
+    changed = false;
+    // First move loop invariant ops outside of their loop. This needs to be
+    // done before as we cannot move ops without interrupting the function walk.
+    root->walk(
+        [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
+
+    root->walk([&](vector::ExtractOp extractOp) {
+      LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
+                        << *extractOp.getOperation() << "\n");
+
+      auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp());
+      if (!loop)
+        return WalkResult::advance();
+
+      // Check that the vector to extract from is an iter_arg
+      auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
+      if (!blockArg)
+        return WalkResult::advance();
+
+      // If the iter_arg does not have only one use, it won't be possible to
+      // hoist the extractOp out.
+      if (!blockArg.hasOneUse())
+        return WalkResult::advance();
+
+      auto initArg = loop.getTiedLoopInit(blockArg)->get();
+      auto index = blockArg.getArgNumber() - loop.getNumInductionVars();
+
+      // Check that the loop yields a broadcast
+      auto lastOp = loop.getBody()->getTerminator();
+      auto yieldOp = dyn_cast<scf::YieldOp>(lastOp);
+      if (!yieldOp)
+        return WalkResult::advance();
+
+      auto broadcast = dyn_cast<vector::BroadcastOp>(
+          yieldOp->getOperand(index).getDefiningOp());
+
+      LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");
+
+      Type broadcastInputType = broadcast.getSourceType();
+      if (broadcastInputType != extractOp.getType())
+        return WalkResult::advance();
+
+      // The position of the extract must be defined outside of the loop if
+      // it is dynamic
+      for (auto operand : extractOp.getDynamicPosition())
+        if (!loop.isDefinedOutsideOfLoop(operand))
+          return WalkResult::advance();
+
+      extractOp.getVectorMutable().assign(initArg);
+      loop.moveOutOfLoop(extractOp);
+      broadcast->moveAfter(loop);
+
+      IRRewriter rewriter(extractOp.getContext());
+      auto newLoop = replaceWithDifferentYield(
+          rewriter, loop, extractOp.getResult(), index, broadcast.getSource());
+
+      LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
+
+      newLoop.getResult(index).replaceAllUsesWith(broadcast);
+      broadcast.getSourceMutable().assign(newLoop.getResult(index));
+
+      changed = true;
+      return WalkResult::interrupt();
+    });
+  }
+}
+
 static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
                                 LoopLikeOpInterface loop) {
   Value source = transferRead.getSource();
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 550ffbc7bab678..5a640348be90cf 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -565,3 +565,112 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// Test hoisting of vector.extract/vector.broadcast pairs
+
+// CHECK-LABEL:  func.func @hoist_vector_broadcasts
+//       CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>) -> vector<3x4xf32> {
+//       CHECK:        %[[EXTRACT:.+]] = vector.extract %[[VEC]][0] : vector<4xf32> from vector<3x4xf32>
+//       CHECK-NEXT:   %[[LOOP:.+]] = scf.for {{.*}} {
+//       CHECK-NEXT:     %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
+//       CHECK-NEXT:     scf.yield %[[USE]] : vector<4xf32>
+//       CHECK-NEXT:   }
+//       CHECK-NEXT:   %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
+//       CHECK-NEXT:   return %[[BCAST]] : vector<3x4xf32>
+
+func.func @hoist_vector_broadcasts(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>) -> vector<3x4xf32> {
+  %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
+    %extract = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
+    %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
+    %broadcast = vector.broadcast %use : vector<4xf32> to vector<3x4xf32>
+    scf.yield %broadcast : vector<3x4xf32>
+  }
+  return %bcast_vec : vector<3x4xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_broadcasts %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Test hoisting of vector.extract/vector.broadcast pairs with dynamic position
+
+// CHECK-LABEL:  func.func @hoist_vector_broadcasts_dynamic
+//       CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> {
+//       CHECK:        %[[EXTRACT:.+]] = vector.extract %[[VEC]][%[[POS]]] : vector<4xf32> from vector<3x4xf32>
+//       CHECK-NEXT:   %[[LOOP:.+]] = scf.for {{.*}} {
+//       CHECK-NEXT:     %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
+//       CHECK-NEXT:     scf.yield %[[USE]] : vector<4xf32>
+//       CHECK-NEXT:   }
+//       CHECK-NEXT:   %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
+//       CHECK-NEXT:   return %[[BCAST]] : vector<3x4xf32>
+
+func.func @hoist_vector_broadcasts_dynamic(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>, %pos: index) -> vector<3x4xf32> {
+  %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
+    %extract = vector.extract %iarg[%pos] : vector<4xf32> from vector<3x4xf32>
+    %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
+    %broadcast = vector.broadcast %use : vector<4xf32> to vector<3x4xf32>
+    scf.yield %broadcast : vector<3x4xf32>
+  }
+  return %bcast_vec : vector<3x4xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_broadcasts %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Test hoisting of vector.extract/vector.broadcast pairs with multiple iter_args
+
+// CHECK-LABEL:  func.func @hoist_vector_broadcasts_multiple
+//       CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC1:.+]]: vector<3x4xf32>,
+//       CHECK-SAME:  %[[VEC2:.+]]: vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
+//       CHECK-DAG:     %[[EXTRACT1:.+]] = vector.extract %[[VEC1]][0] : vector<4xf32> from vector<3x4xf32>
+//       CHECK-DAG:     %[[EXTRACT2:.+]] = vector.extract %[[VEC2]][1] : vector<5xf32> from vector<3x5xf32>
+//       CHECK-NEXT:    %[[LOOP:.+]]:2 = scf.for {{.*}} {
+//       CHECK-DAG:       %[[USE1:.+]] = "some_use1"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
+//       CHECK-DAG:       %[[USE2:.+]] = "some_use2"({{.*}}) : (vector<5xf32>) -> vector<5xf32>
+//       CHECK-NEXT:      scf.yield %[[USE1]], %[[USE2]]  : vector<4xf32>, vector<5xf32>
+//       CHECK-NEXT:    }
+//       CHECK-DAG:     %[[BCAST1:.+]] = vector.broadcast %[[LOOP]]#0 : vector<4xf32> to vector<3x4xf32>
+//       CHECK-DAG:     %[[BCAST2:.+]] = vector.broadcast %[[LOOP]]#1 : vector<5xf32> to vector<3x5xf32>
+//       CHECK-NEXT:    return %[[BCAST1]], %[[BCAST2]] : vector<3x4xf32>, vector<3x5xf32>
+
+func.func @hoist_vector_broadcasts_multiple(%lb : index, %ub : index, %step : index, %vec1 : vector<3x4xf32>, %vec2 : vector<3x5xf32>) ->  (vector<3x4xf32>, vector<3x5xf32>) {
+  %bcast_vec:2 = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec1, %iarg2 = %vec2) -> (vector<3x4xf32>, vector<3x5xf32>) {
+    %extract1 = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
+    %extract2 = vector.extract %iarg2[1] : vector<5xf32> from vector<3x5xf32>
+    %use1 = "some_use1"(%extract1) : (vector<4xf32>) -> vector<4xf32>
+    %use2 = "some_use2"(%extract2) : (vector<5xf32>) -> vector<5xf32>
+    %broadcast1 = vector.broadcast %use1 : vector<4xf32> to vector<3x4xf32>
+    %broadcast2 = vector.broadcast %use2 : vector<5xf32> to vector<3x5xf32>
+    scf.yield %broadcast1, %broadcast2 : vector<3x4xf32>,vector<3x5xf32>
+  }
+  return %bcast_vec#0, %bcast_vec#1 :  vector<3x4xf32>, vector<3x5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_broadcasts %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
\ No newline at end of file

>From 13d780ee753f78efd85807d8b07637008a2466d1 Mon Sep 17 00:00:00 2001
From: Steven Varoumas <steven.varoumas1 at huawei.com>
Date: Mon, 15 Apr 2024 22:27:56 +0800
Subject: [PATCH 2/2] review comments

---
 .../Linalg/TransformOps/LinalgTransformOps.td |  2 +-
 .../TransformOps/LinalgTransformOps.cpp       |  2 +-
 .../Dialect/Linalg/Transforms/Hoisting.cpp    | 35 ++++++++++---------
 3 files changed, 20 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 313f2aca8f0c9f..157dc671f72008 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2236,7 +2236,7 @@ def HoistRedundantVectorBroadcastsOp :
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
          ::mlir::transform::TransformRewriter &rewriter,
-         ::mlir::func::FuncOp target,
+         ::mlir::Operation *target,
          ::mlir::transform::ApplyToEachResultList &results,
          ::mlir::transform::TransformState &state);
    }];
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 7166bc19745d05..dc2c86a664f440 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3312,7 +3312,7 @@ transform::HoistRedundantVectorTransfersOp::applyToOne(
 
 DiagnosedSilenceableFailure
 transform::HoistRedundantVectorBroadcastsOp::applyToOne(
-    transform::TransformRewriter &rewriter, func::FuncOp target,
+    transform::TransformRewriter &rewriter, mlir::Operation *target,
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   linalg::hoistRedundantVectorBroadcasts(target);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 98521cd745216c..9a34b7b11c5702 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -59,7 +59,7 @@ scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop,
 
   // Generate the new yield with the replaced operand
   auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
-  yieldOp->getOperand(index).replaceAllUsesWith(newYieldValue);
+  rewriter.replaceAllUsesWith(yieldOp->getOperand(index), newYieldValue);
 
   // Move the loop body to the new op.
   rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(),
@@ -74,19 +74,19 @@ scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop,
 
 // Hoist out a pair of corresponding vector.extract+vector.broadcast
 // operations. This function transforms a loop like this:
-//  %loop = scf.for _ = _ to _ step _ iter_args(%iterarg = %v) -> (t1) {
-//   %e = vector.extract %iterarg : t1 to t2
-//   %u = // do something with %e : t2
+//  %res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) {
+//   %e = vector.extract %iarg : t1 to t2
+//   %u = "some_use"(%e) : (t2) -> t2
 //   %b = vector.broadcast %u : t2 to t1
 //   scf.yield %b : t1
 //  }
 // into the following:
 //  %e = vector.extract %v: t1 to t2
-//  %loop' = scf.for _ = _ to _ step _ iter_args(%iterarg = %e) -> (t2) {
-//   %u' = // do something with %iterarg : t2
+//  %res' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) {
+//   %u' = "some_use"(%iarg) : (t2) -> t2
 //   scf.yield %u' : t2
 //  }
-//  %loop = vector.broadcast %loop' : t2 to t1
+//  %res = vector.broadcast %res' : t2 to t1
 void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
   bool changed = true;
   while (changed) {
@@ -118,14 +118,12 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
       auto index = blockArg.getArgNumber() - loop.getNumInductionVars();
 
       // Check that the loop yields a broadcast
-      auto lastOp = loop.getBody()->getTerminator();
-      auto yieldOp = dyn_cast<scf::YieldOp>(lastOp);
-      if (!yieldOp)
+      auto yieldedVal =
+          loop.getTiedLoopYieldedValue(blockArg)->get().getDefiningOp();
+      auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal);
+      if (!broadcast)
         return WalkResult::advance();
 
-      auto broadcast = dyn_cast<vector::BroadcastOp>(
-          yieldOp->getOperand(index).getDefiningOp());
-
       LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");
 
       Type broadcastInputType = broadcast.getSourceType();
@@ -138,18 +136,21 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
         if (!loop.isDefinedOutsideOfLoop(operand))
           return WalkResult::advance();
 
+      IRRewriter rewriter(extractOp.getContext());
+
       extractOp.getVectorMutable().assign(initArg);
       loop.moveOutOfLoop(extractOp);
-      broadcast->moveAfter(loop);
+      rewriter.moveOpAfter(broadcast, loop);
 
-      IRRewriter rewriter(extractOp.getContext());
       auto newLoop = replaceWithDifferentYield(
           rewriter, loop, extractOp.getResult(), index, broadcast.getSource());
 
       LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
 
-      newLoop.getResult(index).replaceAllUsesWith(broadcast);
-      broadcast.getSourceMutable().assign(newLoop.getResult(index));
+      rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast);
+      rewriter.modifyOpInPlace(broadcast, [&] {
+        broadcast.getSourceMutable().assign(newLoop.getResult(index));
+      });
 
       changed = true;
       return WalkResult::interrupt();



More information about the Mlir-commits mailing list