[Mlir-commits] [mlir] Enable LICM for ops with only read side effects in scf.for (PR #120302)

Arda Unal llvmlistbot at llvm.org
Mon Jan 6 15:45:21 PST 2025


https://github.com/ardaunal updated https://github.com/llvm/llvm-project/pull/120302

>From 0e596fdeb855ec27a4d78918971af819c0769825 Mon Sep 17 00:00:00 2001
From: ardau <ardau at meta.com>
Date: Tue, 17 Dec 2024 12:29:21 -0800
Subject: [PATCH 1/3] Enable LICM for ops with read side effects in scf.for
 wrapped by a guard

---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    |  3 +-
 .../mlir/Interfaces/LoopLikeInterface.td      | 20 ++++
 .../mlir/Interfaces/SideEffectInterfaces.h    |  4 +
 .../Transforms/LoopInvariantCodeMotionUtils.h | 12 ++-
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 59 ++++++++++-
 mlir/lib/Interfaces/SideEffectInterfaces.cpp  |  7 ++
 .../Utils/LoopInvariantCodeMotionUtils.cpp    | 96 +++++++++++++++---
 .../loop-invariant-code-motion.mlir           | 99 +++++++++++++++++++
 mlir/test/lib/Dialect/Test/TestOps.td         | 44 +++++++++
 9 files changed, 326 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 23c597a1ca5108..b54df8e3ef313d 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -139,6 +139,7 @@ def ForOp : SCF_Op<"for",
         "getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
         "getLoopUpperBounds", "getYieldedValuesMutable",
         "promoteIfSingleIteration", "replaceWithAdditionalYields",
+        "wrapInTripCountCheck", "unwrapTripCountCheck",
         "yieldTiledValuesAndReplace"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
        ConditionallySpeculatable,
@@ -302,7 +303,7 @@ def ForallOp : SCF_Op<"forall", [
        AttrSizedOperandSegments,
        AutomaticAllocationScope,
        DeclareOpInterfaceMethods<LoopLikeOpInterface,
-          ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars", 
+          ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
            "getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
            "promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
        RecursiveMemoryEffects,
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index c6bffe347419e5..831830130b0ddc 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -79,6 +79,26 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       /*methodBody=*/"",
       /*defaultImplementation=*/"op->moveBefore($_op);"
     >,
+    InterfaceMethod<[{
+        Wraps the loop into a trip-count check.
+      }],
+      /*retTy=*/"FailureOr<std::pair<::mlir::Operation *, ::mlir::Region *>>",
+      /*methodName=*/"wrapInTripCountCheck",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/"return ::mlir::failure();"
+    >,
+    InterfaceMethod<[{
+        Unwraps the trip-count check.
+      }],
+      /*retTy=*/"::llvm::LogicalResult",
+      /*methodName=*/"unwrapTripCountCheck",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::failure();
+      }]
+    >,
     InterfaceMethod<[{
         Promotes the loop body to its containing block if the loop is known to
         have a single iteration. Returns "success" if the promotion was
diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
index aef7ec622fe4f8..1a7f66e2234949 100644
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
@@ -433,6 +433,10 @@ bool wouldOpBeTriviallyDead(Operation *op);
 /// conditions are satisfied.
 bool isMemoryEffectFree(Operation *op);
 
+/// Returns true if the given operation implements `MemoryEffectOpInterface` and
+/// has only read effects.
+bool hasOnlyReadEffect(Operation *op);
+
 /// Returns the side effects of an operation. If the operation has
 /// RecursiveMemoryEffects, include all side effects of child operations.
 ///
diff --git a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
index 3ceef44d799e89..ae6719abe79c00 100644
--- a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
@@ -48,15 +48,19 @@ class Value;
 /// }
 /// ```
 ///
-/// Users must supply three callbacks.
+/// Users must supply five callbacks.
 ///
 /// - `isDefinedOutsideRegion` returns true if the given value is invariant with
 ///   respect to the given region. A common implementation might be:
 ///   `value.getParentRegion()->isProperAncestor(region)`.
 /// - `shouldMoveOutOfRegion` returns true if the provided operation can be
-///   moved of the given region, e.g. if it is side-effect free.
+///   moved of the given region, e.g. if it is side-effect free or has only read
+///   side effects.
+/// - `wrapInGuard` wraps the given operation in a trip-count check guard.
 /// - `moveOutOfRegion` moves the operation out of the given region. A common
 ///   implementation might be: `op->moveBefore(region->getParentOp())`.
+/// - `unwrapGuard` unwraps the trip-count check if there is no op guarded by
+///   this check.
 ///
 /// An operation is moved if all of its operands satisfy
 /// `isDefinedOutsideRegion` and it satisfies `shouldMoveOutOfRegion`.
@@ -66,7 +70,9 @@ size_t moveLoopInvariantCode(
     ArrayRef<Region *> regions,
     function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
     function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
-    function_ref<void(Operation *, Region *)> moveOutOfRegion);
+    function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
+    function_ref<void(Operation *, Region *)> moveOutOfRegion,
+    function_ref<LogicalResult()> unwrapGuard);
 
 /// Move side-effect free loop invariant code out of a loop-like op using
 /// methods provided by the interface.
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index eded1c394f126c..d246ebdaaea2f3 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -395,6 +395,60 @@ std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
 
 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
 
+FailureOr<std::pair<Operation *, Region *>> ForOp::wrapInTripCountCheck() {
+
+  IRRewriter rewriter(this->getContext());
+  OpBuilder::InsertionGuard insertGuard(rewriter);
+  rewriter.setInsertionPointAfter(this->getOperation());
+
+  auto loc = this->getLoc();
+  auto cmpIOp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
+                                               this->getUpperBound(),
+                                               this->getLowerBound());
+  scf::YieldOp yieldInThen;
+  // Create the trip-count check.
+  auto ifOp = rewriter.create<scf::IfOp>(
+      loc, cmpIOp,
+      [&](OpBuilder &builder, Location loc) {
+        yieldInThen = builder.create<scf::YieldOp>(loc, this->getResults());
+      },
+      [&](OpBuilder &builder, Location loc) {
+        builder.create<scf::YieldOp>(loc, this->getInitArgs());
+      });
+
+  for (auto [forOpResult, ifOpResult] :
+       llvm::zip(this->getResults(), ifOp.getResults()))
+    rewriter.replaceAllUsesExcept(forOpResult, ifOpResult, yieldInThen);
+  // Move the scf.for into the then block.
+  rewriter.moveOpBefore(this->getOperation(), yieldInThen);
+  return std::make_pair(ifOp.getOperation(), &this->getRegion());
+}
+
+LogicalResult ForOp::unwrapTripCountCheck() {
+  auto ifOp = (*this)->getParentRegion()->getParentOp();
+  if (!isa<scf::IfOp>(ifOp))
+    return failure();
+
+  IRRewriter rewriter(ifOp->getContext());
+  OpBuilder::InsertionGuard insertGuard(rewriter);
+  rewriter.setInsertionPoint(ifOp);
+
+  auto cmpOp = ifOp->getOperand(0).getDefiningOp();
+  if (!isa<arith::CmpIOp>(cmpOp))
+    return failure();
+
+  auto wrappedForOp = this->getOperation();
+  rewriter.moveOpBefore(wrappedForOp, ifOp);
+
+  for (auto [forOpResult, ifOpResult] :
+       llvm::zip(wrappedForOp->getResults(), ifOp->getResults()))
+    rewriter.replaceAllUsesWith(ifOpResult, forOpResult);
+
+  rewriter.eraseOp(ifOp);
+  rewriter.eraseOp(cmpOp);
+  return success();
+}
+
 /// Promotes the loop body of a forOp to its containing block if the forOp
 /// it can be determined that the loop has a single iteration.
 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
@@ -3397,9 +3451,8 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
 
   if (functionType.getNumInputs() != operands.size()) {
     return parser.emitError(typeLoc)
-           << "expected as many input types as operands "
-           << "(expected " << operands.size() << " got "
-           << functionType.getNumInputs() << ")";
+           << "expected as many input types as operands " << "(expected "
+           << operands.size() << " got " << functionType.getNumInputs() << ")";
   }
 
   // Resolve input operands.
diff --git a/mlir/lib/Interfaces/SideEffectInterfaces.cpp b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
index c9feb001a19844..f45d5f3d227407 100644
--- a/mlir/lib/Interfaces/SideEffectInterfaces.cpp
+++ b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
@@ -306,6 +306,13 @@ bool mlir::wouldOpBeTriviallyDead(Operation *op) {
   return wouldOpBeTriviallyDeadImpl(op);
 }
 
+bool mlir::hasOnlyReadEffect(Operation *op) {
+  if (auto memEffects = dyn_cast<MemoryEffectOpInterface>(op)) {
+    return memEffects.onlyHasEffect<MemoryEffects::Read>();
+  }
+  return false;
+}
+
 bool mlir::isMemoryEffectFree(Operation *op) {
   if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
     if (!memInterface.hasNoEffect())
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 7460746934a78c..b0bf86c0c8e878 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -56,48 +56,117 @@ static bool canBeHoisted(Operation *op,
       op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
 }
 
+static bool dependsOnGuarded(Operation *op,
+                             function_ref<bool(OpOperand &)> condition) {
+  auto walkFn = [&](Operation *child) {
+    for (OpOperand &operand : child->getOpOperands()) {
+      if (!condition(operand))
+        return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  };
+  return op->walk(walkFn).wasInterrupted();
+}
+
+static bool dependsOnGuarded(Operation *op,
+                             function_ref<bool(Value)> definedOutsideGuard) {
+  return dependsOnGuarded(op, [&](OpOperand &operand) {
+    return definedOutsideGuard(operand.get());
+  });
+}
+
+static bool loopSideEffectFreeOrHasOnlyReadEffect(Operation *loop) {
+  for (Region &region : loop->getRegions()) {
+    for (Block &block : region.getBlocks()) {
+      for (Operation &op : block.getOperations()) {
+        if (!isMemoryEffectFree(&op) && !hasOnlyReadEffect(&op))
+          return false;
+      }
+    }
+  }
+  return true;
+}
+
 size_t mlir::moveLoopInvariantCode(
     ArrayRef<Region *> regions,
     function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
     function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
-    function_ref<void(Operation *, Region *)> moveOutOfRegion) {
+    function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
+    function_ref<void(Operation *, Region *)> moveOutOfRegion,
+    function_ref<LogicalResult()> unwrapGuard) {
   size_t numMoved = 0;
 
   for (Region *region : regions) {
     LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
                             << *region->getParentOp() << "\n");
 
+    auto loopSideEffectFreeOrHasOnlyReadSideEffect =
+        loopSideEffectFreeOrHasOnlyReadEffect(region->getParentOp());
+
+    size_t numMovedWithoutGuard = 0;
+
+    FailureOr<std::pair<Operation *, Region *>> ifOpAndRegion = wrapInGuard();
+    Region *loopRegion = region;
+    auto isLoopWrapped = false;
+    if (succeeded(ifOpAndRegion)) {
+      loopRegion = ifOpAndRegion->second;
+      isLoopWrapped = true;
+    }
+
     std::queue<Operation *> worklist;
     // Add top-level operations in the loop body to the worklist.
-    for (Operation &op : region->getOps())
+    for (Operation &op : loopRegion->getOps())
       worklist.push(&op);
 
     auto definedOutside = [&](Value value) {
-      return isDefinedOutsideRegion(value, region);
+      return isDefinedOutsideRegion(value, loopRegion);
+    };
+
+    auto definedOutsideGuard = [&](Value value) {
+      return isDefinedOutsideRegion(value, loopRegion->getParentRegion());
     };
 
     while (!worklist.empty()) {
       Operation *op = worklist.front();
       worklist.pop();
       // Skip ops that have already been moved. Check if the op can be hoisted.
-      if (op->getParentRegion() != region)
+      if (op->getParentRegion() != loopRegion)
         continue;
 
       LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
-      if (!shouldMoveOutOfRegion(op, region) ||
+
+      if (!shouldMoveOutOfRegion(op, loopRegion) ||
           !canBeHoisted(op, definedOutside))
         continue;
+      // Can only hoist pure ops (side-effect free) when there is an op with
+      // write side effects in the loop.
+      if (!loopSideEffectFreeOrHasOnlyReadSideEffect && !isMemoryEffectFree(op))
+        continue;
 
       LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
-      moveOutOfRegion(op, region);
+
+      auto moveWithoutGuard = isMemoryEffectFree(op) &&
+                              !dependsOnGuarded(op, definedOutsideGuard) &&
+                              isLoopWrapped;
+      numMovedWithoutGuard += moveWithoutGuard;
+
+      moveOutOfRegion(op, moveWithoutGuard ? loopRegion->getParentRegion()
+                                           : loopRegion);
       ++numMoved;
 
       // Since the op has been moved, we need to check its users within the
       // top-level of the loop body.
       for (Operation *user : op->getUsers())
-        if (user->getParentRegion() == region)
+        if (user->getParentRegion() == loopRegion)
           worklist.push(user);
     }
+
+    // Unwrap the loop if it was wrapped but no ops were moved in the guard.
+    if (isLoopWrapped && numMovedWithoutGuard == numMoved) {
+      auto tripCountCheckUnwrapped = unwrapGuard();
+      if (failed(tripCountCheckUnwrapped))
+        llvm_unreachable("Should not fail unwrapping trip-count check");
+    }
   }
 
   return numMoved;
@@ -106,13 +175,18 @@ size_t mlir::moveLoopInvariantCode(
 size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
   return moveLoopInvariantCode(
       loopLike.getLoopRegions(),
-      [&](Value value, Region *) {
-        return loopLike.isDefinedOutsideOfLoop(value);
+      [&](Value value, Region *region) {
+        return !region->isAncestor(value.getParentRegion());
       },
       [&](Operation *op, Region *) {
-        return isMemoryEffectFree(op) && isSpeculatable(op);
+        return isSpeculatable(op) &&
+               (isMemoryEffectFree(op) || hasOnlyReadEffect(op));
+      },
+      [&]() { return loopLike.wrapInTripCountCheck(); },
+      [&](Operation *op, Region *region) {
+        op->moveBefore(region->getParentOp());
       },
-      [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
+      [&]() { return loopLike.unwrapTripCountCheck(); });
 }
 
 namespace {
diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir
index e4c423ce7052bf..f2c4f09e1e7b22 100644
--- a/mlir/test/Transforms/loop-invariant-code-motion.mlir
+++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir
@@ -593,6 +593,105 @@ func.func @test_recursively_speculatable_op_failure(%lb: index, %ub: index, %ste
   return
 }
 
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success
+func.func @test_speculatable_op_with_read_side_effect_success(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: arith.cmpi
+  // CHECK-NEXT: test.always_speculatable_op
+  // CHECK-NEXT: scf.if
+  // CHECK-NEXT: test.speculatable_op_with_memread
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK-NOT: test.speculatable_op_with_memread
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always = "test.always_speculatable_op"() : () -> i32
+    %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i: index to i32
+    %i_sum = arith.addi %acc, %i_cast : i32
+    %test_sum = arith.addi %i_sum, %always_read : i32
+    scf.yield %test_sum : i32
+  }
+  return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success_with_dependents
+func.func @test_speculatable_op_with_read_side_effect_success_with_dependents(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: arith.cmpi
+  // CHECK-NEXT: test.always_speculatable_op
+  // CHECK-NEXT: scf.if
+  // CHECK-NEXT: test.speculatable_op_with_memread
+  // CHECK-NEXT: arith.addi
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK-NOT: test.speculatable_op_with_memread
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always = "test.always_speculatable_op"() : () -> i32
+    %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %add = arith.addi %always_read, %cst_0 : i32
+    %i_cast = arith.index_cast %i: index to i32
+    %i_sum = arith.addi %acc, %i_cast : i32
+    %test_sum = arith.addi %i_sum, %add : i32
+    scf.yield %test_sum : i32
+  }
+  return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_write
+func.func @test_speculatable_op_with_read_side_effect_failure_due_to_write(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: test.always_speculatable_op
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK: test.speculatable_op_with_memread
+  // CHECK: test.speculatable_op_with_memwrite
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always = "test.always_speculatable_op"() : () -> i32
+    %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i: index to i32
+    %i_sum = arith.addi %acc, %i_cast : i32
+    %test_sum = arith.addi %i_sum, %always_read : i32
+    %always_write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
+    scf.yield %test_sum : i32
+  }
+  return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_nested_write
+func.func @test_speculatable_op_with_read_side_effect_failure_due_to_nested_write(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: test.always_speculatable_op
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK: test.speculatable_op_with_memread
+  // CHECK: scf.for
+  // CHECK: scf.if
+  // CHECK: test.speculatable_op_with_memwrite
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always = "test.always_speculatable_op"() : () -> i32
+    %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i: index to i32
+    %i_sum = arith.addi %acc, %i_cast : i32
+    %test_sum = arith.addi %i_sum, %always_read : i32
+    scf.for %j = %lb to %ub step %step {
+      %eq42 = arith.cmpi eq, %j, %ind_42 : index
+      scf.if %eq42 {
+        %always_write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
+      }
+    }
+    scf.yield %test_sum : i32
+  }
+  return %sum_result : i32
+}
+
 // -----
 
 func.func @speculate_tensor_dim_unknown_rank_unknown_dim(
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index d24d52f356d88f..1b834a0a9266b2 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2916,6 +2916,50 @@ def RecursivelySpeculatableOp : TEST_Op<"recursively_speculatable_op", [
   let regions = (region SizedRegion<1>:$body);
 }
 
+def SpeculatableOpWithReadSideEffect : TEST_Op<"speculatable_op_with_memread",
+    [ConditionallySpeculatable, MemoryEffects<[MemRead]>]> {
+  let description = [{
+    Op used to test LICM conditional speculation.  This op can always be
+    speculatively executed and has only memory read effect.
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$inputs);
+  let results = (outs Variadic<AnyType>:$outputs);
+
+  let extraClassDeclaration = [{
+    ::mlir::Speculation::Speculatability getSpeculatability();
+  }];
+
+  let extraClassDefinition = [{
+    ::mlir::Speculation::Speculatability
+    SpeculatableOpWithReadSideEffect::getSpeculatability() {
+      return ::mlir::Speculation::Speculatable;
+    }
+  }];
+}
+
+def SpeculatableOpWithWriteSideEffect : TEST_Op<"speculatable_op_with_memwrite",
+    [ConditionallySpeculatable, MemoryEffects<[MemWrite]>]> {
+  let description = [{
+    Op used to test LICM conditional speculation.  This op can always be
+    speculatively executed and has only memory write effect.
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$inputs);
+  let results = (outs Variadic<AnyType>:$outputs);
+
+  let extraClassDeclaration = [{
+    ::mlir::Speculation::Speculatability getSpeculatability();
+  }];
+
+  let extraClassDefinition = [{
+    ::mlir::Speculation::Speculatability
+    SpeculatableOpWithWriteSideEffect::getSpeculatability() {
+      return ::mlir::Speculation::Speculatable;
+    }
+  }];
+}
+
 //===---------------------------------------------------------------------===//
 // Test CSE
 //===---------------------------------------------------------------------===//

>From c5328108c35037fb408f1d4830231c6b18cbc74d Mon Sep 17 00:00:00 2001
From: ardau <ardau at meta.com>
Date: Thu, 19 Dec 2024 17:01:55 -0800
Subject: [PATCH 2/3] fixup! Enable LICM for ops with read side effects in
 scf.for wrapped by a guard

---
 mlir/lib/Interfaces/SideEffectInterfaces.cpp    | 17 +++++++++++++++--
 .../Utils/LoopInvariantCodeMotionUtils.cpp      |  2 +-
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Interfaces/SideEffectInterfaces.cpp b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
index f45d5f3d227407..ccbcb96eec58f5 100644
--- a/mlir/lib/Interfaces/SideEffectInterfaces.cpp
+++ b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
@@ -308,9 +308,22 @@ bool mlir::wouldOpBeTriviallyDead(Operation *op) {
 
 bool mlir::hasOnlyReadEffect(Operation *op) {
   if (auto memEffects = dyn_cast<MemoryEffectOpInterface>(op)) {
-    return memEffects.onlyHasEffect<MemoryEffects::Read>();
+    if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
+      return memEffects.onlyHasEffect<MemoryEffects::Read>();
+  } else if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) {
+    // Otherwise, if the op does not implement the memory effect interface and
+    // it does not have recursive side effects, then it cannot be known that the
+    // op is moveable.
+    return false;
   }
-  return false;
+
+  // Recurse into the regions and ensure that all nested ops are memory effect
+  // free.
+  for (Region &region : op->getRegions())
+    for (Operation &op : region.getOps())
+      if (!hasOnlyReadEffect(&op))
+        return false;
+  return true;
 }
 
 bool mlir::isMemoryEffectFree(Operation *op) {
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index b0bf86c0c8e878..1d8fb568635b3f 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -139,7 +139,7 @@ size_t mlir::moveLoopInvariantCode(
           !canBeHoisted(op, definedOutside))
         continue;
       // Can only hoist pure ops (side-effect free) when there is an op with
-      // write side effects in the loop.
+      // write and/or unknown side effects in the loop.
       if (!loopSideEffectFreeOrHasOnlyReadSideEffect && !isMemoryEffectFree(op))
         continue;
 

>From 8a38077cfb16f5ea4c7ab14a82d3891ae67d5d74 Mon Sep 17 00:00:00 2001
From: ardau <ardau at meta.com>
Date: Mon, 6 Jan 2025 15:41:53 -0800
Subject: [PATCH 3/3] fixup! fixup! Enable LICM for ops with read side effects
 in scf.for wrapped by a guard

---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    |   3 +-
 .../mlir/Interfaces/LoopLikeInterface.td      |  20 +---
 .../mlir/Interfaces/SideEffectInterfaces.h    |   6 +-
 .../Transforms/LoopInvariantCodeMotionUtils.h |  18 ++-
 mlir/lib/Dialect/SCF/IR/SCF.cpp               |  69 ++++-------
 mlir/lib/Interfaces/SideEffectInterfaces.cpp  |  30 ++---
 .../Utils/LoopInvariantCodeMotionUtils.cpp    | 102 ++++-------------
 .../loop-invariant-code-motion.mlir           | 107 +++++++++++++-----
 8 files changed, 154 insertions(+), 201 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index b54df8e3ef313d..0bac8bfbbb7499 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -40,7 +40,7 @@ def SCF_Dialect : Dialect {
     and then lowered to some final target like LLVM or SPIR-V.
   }];
 
-  let dependentDialects = ["arith::ArithDialect"];
+  let dependentDialects = ["arith::ArithDialect", "ub::UBDialect"];
 }
 
 // Base class for SCF dialect ops.
@@ -138,6 +138,7 @@ def ForOp : SCF_Op<"for",
        ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
         "getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
         "getLoopUpperBounds", "getYieldedValuesMutable",
+        "moveOutOfLoopWithGuard",
         "promoteIfSingleIteration", "replaceWithAdditionalYields",
         "wrapInTripCountCheck", "unwrapTripCountCheck",
         "yieldTiledValuesAndReplace"]>,
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 831830130b0ddc..c4ff0b485565e0 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -80,23 +80,15 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       /*defaultImplementation=*/"op->moveBefore($_op);"
     >,
     InterfaceMethod<[{
-        Wraps the loop into a trip-count check.
+        Moves the given loop-invariant operation out of the loop with a
+        trip-count guard.
       }],
-      /*retTy=*/"FailureOr<std::pair<::mlir::Operation *, ::mlir::Region *>>",
-      /*methodName=*/"wrapInTripCountCheck",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/"return ::mlir::failure();"
-    >,
-    InterfaceMethod<[{
-        Unwraps the trip-count check.
-      }],
-      /*retTy=*/"::llvm::LogicalResult",
-      /*methodName=*/"unwrapTripCountCheck",
-      /*args=*/(ins),
+      /*retTy=*/"void",
+      /*methodName=*/"moveOutOfLoopWithGuard",
+      /*args=*/(ins "::mlir::Operation *":$op),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return ::mlir::failure();
+        return;
       }]
     >,
     InterfaceMethod<[{
diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
index 1a7f66e2234949..f8142567c60607 100644
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
@@ -433,9 +433,9 @@ bool wouldOpBeTriviallyDead(Operation *op);
 /// conditions are satisfied.
 bool isMemoryEffectFree(Operation *op);
 
-/// Returns true if the given operation implements `MemoryEffectOpInterface` and
-/// has only read effects.
-bool hasOnlyReadEffect(Operation *op);
+/// Returns true if the given operation is free of memory effects or has only
+/// read effects.
+bool isMemoryEffectFreeOrOnlyRead(Operation *op);
 
 /// Returns the side effects of an operation. If the operation has
 /// RecursiveMemoryEffects, include all side effects of child operations.
diff --git a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
index ae6719abe79c00..83a70abe0a5edf 100644
--- a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
@@ -47,8 +47,7 @@ class Value;
 ///   }
 /// }
 /// ```
-///
-/// Users must supply five callbacks.
+/// Users must supply four callbacks.
 ///
 /// - `isDefinedOutsideRegion` returns true if the given value is invariant with
 ///   respect to the given region. A common implementation might be:
@@ -56,11 +55,11 @@ class Value;
 /// - `shouldMoveOutOfRegion` returns true if the provided operation can be
 ///   moved of the given region, e.g. if it is side-effect free or has only read
 ///   side effects.
-/// - `wrapInGuard` wraps the given operation in a trip-count check guard.
-/// - `moveOutOfRegion` moves the operation out of the given region. A common
-///   implementation might be: `op->moveBefore(region->getParentOp())`.
-/// - `unwrapGuard` unwraps the trip-count check if there is no op guarded by
-///   this check.
+/// - `moveOutOfRegionWithoutGuard` moves the operation out of the given region
+///   without a guard. A common implementation might be:
+///   `op->moveBefore(region->getParentOp())`.
+/// - `moveOutOfRegionWithGuard` moves the operation out of the given region
+///   with a guard.
 ///
 /// An operation is moved if all of its operands satisfy
 /// `isDefinedOutsideRegion` and it satisfies `shouldMoveOutOfRegion`.
@@ -70,9 +69,8 @@ size_t moveLoopInvariantCode(
     ArrayRef<Region *> regions,
     function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
     function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
-    function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
-    function_ref<void(Operation *, Region *)> moveOutOfRegion,
-    function_ref<LogicalResult()> unwrapGuard);
+    function_ref<void(Operation *)> moveOutOfRegionWithoutGuard,
+    function_ref<void(Operation *)> moveOutOfRegionWithGuard);
 
 /// Move side-effect free loop invariant code out of a loop-like op using
 /// methods provided by the interface.
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index d246ebdaaea2f3..0974cc214f99f2 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
@@ -395,58 +396,38 @@ std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
 
 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
 
-FailureOr<std::pair<Operation *, Region *>> ForOp::wrapInTripCountCheck() {
-
+/// Moves `op` into an scf.if placed before the loop and guarded by `lb < ub`;
+/// the else branch yields one ub.poison per result of `op`.
+void ForOp::moveOutOfLoopWithGuard(Operation *op) {
   IRRewriter rewriter(this->getContext());
   OpBuilder::InsertionGuard insertGuard(rewriter);
-  rewriter.setInsertionPointAfter(this->getOperation());
-
-  auto loc = this->getLoc();
-  auto cmpIOp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
-                                               this->getUpperBound(),
-                                               this->getLowerBound());
-  scf::YieldOp yieldInThen;
+  rewriter.setInsertionPoint(this->getOperation());
+  Location loc = this->getLoc();
+  // NOTE(review): `ult` is unsigned, but scf.for bounds look signed (lowered to
+  // `slt`); with lb < 0 <= ub this guard fails though the loop runs. Confirm.
+  arith::CmpIOp cmpIOp = rewriter.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::ult, this->getLowerBound(),
+      this->getUpperBound());
   // Create the trip-count check.
-  auto ifOp = rewriter.create<scf::IfOp>(
+  scf::YieldOp thenYield;
+  scf::IfOp ifOp = rewriter.create<scf::IfOp>(
       loc, cmpIOp,
       [&](OpBuilder &builder, Location loc) {
-        yieldInThen = builder.create<scf::YieldOp>(loc, this->getResults());
+        thenYield = builder.create<scf::YieldOp>(loc, op->getResults());
       },
       [&](OpBuilder &builder, Location loc) {
-        builder.create<scf::YieldOp>(loc, this->getInitArgs());
+        SmallVector<Value> poisonResults;
+        poisonResults.reserve(op->getNumResults());
+        for (Type type : op->getResultTypes())
+          poisonResults.push_back(
+              builder.create<ub::PoisonOp>(loc, type, nullptr));
+        builder.create<scf::YieldOp>(loc, poisonResults);
       });
-
-  for (auto [forOpResult, ifOpResult] :
-       llvm::zip(this->getResults(), ifOp.getResults()))
-    rewriter.replaceAllUsesExcept(forOpResult, ifOpResult, yieldInThen);
-  // Move the scf.for into the then block.
-  rewriter.moveOpBefore(this->getOperation(), yieldInThen);
-  return std::make_pair(ifOp.getOperation(), &this->getRegion());
-}
-
-LogicalResult ForOp::unwrapTripCountCheck() {
-  auto ifOp = (*this)->getParentRegion()->getParentOp();
-  if (!isa<scf::IfOp>(ifOp))
-    return failure();
-
-  IRRewriter rewriter(ifOp->getContext());
-  OpBuilder::InsertionGuard insertGuard(rewriter);
-  rewriter.setInsertionPoint(ifOp);
-
-  auto cmpOp = ifOp->getOperand(0).getDefiningOp();
-  if (!isa<arith::CmpIOp>(cmpOp))
-    return failure();
-
-  auto wrappedForOp = this->getOperation();
-  rewriter.moveOpBefore(wrappedForOp, ifOp);
-
-  for (auto [forOpResult, ifOpResult] :
-       llvm::zip(wrappedForOp->getResults(), ifOp->getResults()))
-    rewriter.replaceAllUsesWith(ifOpResult, forOpResult);
-
-  rewriter.eraseOp(ifOp);
-  rewriter.eraseOp(cmpOp);
-  return success();
+  for (auto [opResult, ifOpResult] :
+       llvm::zip_equal(op->getResults(), ifOp->getResults()))
+    rewriter.replaceAllUsesExcept(opResult, ifOpResult, thenYield);
+  // Move the op into the then block.
+  rewriter.moveOpBefore(op, thenYield);
 }
 
 /// Promotes the loop body of a forOp to its containing block if the forOp
diff --git a/mlir/lib/Interfaces/SideEffectInterfaces.cpp b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
index ccbcb96eec58f5..9d681ba59b8b28 100644
--- a/mlir/lib/Interfaces/SideEffectInterfaces.cpp
+++ b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
@@ -306,26 +306,6 @@ bool mlir::wouldOpBeTriviallyDead(Operation *op) {
   return wouldOpBeTriviallyDeadImpl(op);
 }
 
-bool mlir::hasOnlyReadEffect(Operation *op) {
-  if (auto memEffects = dyn_cast<MemoryEffectOpInterface>(op)) {
-    if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
-      return memEffects.onlyHasEffect<MemoryEffects::Read>();
-  } else if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) {
-    // Otherwise, if the op does not implement the memory effect interface and
-    // it does not have recursive side effects, then it cannot be known that the
-    // op is moveable.
-    return false;
-  }
-
-  // Recurse into the regions and ensure that all nested ops are memory effect
-  // free.
-  for (Region &region : op->getRegions())
-    for (Operation &op : region.getOps())
-      if (!hasOnlyReadEffect(&op))
-        return false;
-  return true;
-}
-
 bool mlir::isMemoryEffectFree(Operation *op) {
   if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
     if (!memInterface.hasNoEffect())
@@ -383,6 +363,16 @@ mlir::getEffectsRecursively(Operation *rootOp) {
   return effects;
 }
 
+bool mlir::isMemoryEffectFreeOrOnlyRead(Operation *op) {
+  // Fail when getEffectsRecursively yields no value; otherwise require every
+  // effect to be a read (no effects at all, i.e. memory-effect free, passes).
+  auto effects = getEffectsRecursively(op);
+  return effects &&
+         std::none_of(effects->begin(), effects->end(), [](const auto &e) {
+           return !isa<MemoryEffects::Read>(e.getEffect());
+         });
+}
+
 bool mlir::isSpeculatable(Operation *op) {
   auto conditionallySpeculatable = dyn_cast<ConditionallySpeculatable>(op);
   if (!conditionallySpeculatable)
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 1d8fb568635b3f..094db01176a28c 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -56,86 +56,41 @@ static bool canBeHoisted(Operation *op,
       op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
 }
 
-static bool dependsOnGuarded(Operation *op,
-                             function_ref<bool(OpOperand &)> condition) {
-  auto walkFn = [&](Operation *child) {
-    for (OpOperand &operand : child->getOpOperands()) {
-      if (!condition(operand))
-        return WalkResult::interrupt();
-    }
-    return WalkResult::advance();
-  };
-  return op->walk(walkFn).wasInterrupted();
-}
-
-static bool dependsOnGuarded(Operation *op,
-                             function_ref<bool(Value)> definedOutsideGuard) {
-  return dependsOnGuarded(op, [&](OpOperand &operand) {
-    return definedOutsideGuard(operand.get());
-  });
-}
-
-static bool loopSideEffectFreeOrHasOnlyReadEffect(Operation *loop) {
-  for (Region &region : loop->getRegions()) {
-    for (Block &block : region.getBlocks()) {
-      for (Operation &op : block.getOperations()) {
-        if (!isMemoryEffectFree(&op) && !hasOnlyReadEffect(&op))
-          return false;
-      }
-    }
-  }
-  return true;
-}
-
 size_t mlir::moveLoopInvariantCode(
     ArrayRef<Region *> regions,
     function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
     function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
-    function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
-    function_ref<void(Operation *, Region *)> moveOutOfRegion,
-    function_ref<LogicalResult()> unwrapGuard) {
+    function_ref<void(Operation *)> moveOutOfRegionWithoutGuard,
+    function_ref<void(Operation *)> moveOutOfRegionWithGuard) {
   size_t numMoved = 0;
 
   for (Region *region : regions) {
     LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
                             << *region->getParentOp() << "\n");
 
-    auto loopSideEffectFreeOrHasOnlyReadSideEffect =
-        loopSideEffectFreeOrHasOnlyReadEffect(region->getParentOp());
-
-    size_t numMovedWithoutGuard = 0;
-
-    FailureOr<std::pair<Operation *, Region *>> ifOpAndRegion = wrapInGuard();
-    Region *loopRegion = region;
-    auto isLoopWrapped = false;
-    if (succeeded(ifOpAndRegion)) {
-      loopRegion = ifOpAndRegion->second;
-      isLoopWrapped = true;
-    }
+    bool anyOpHoistedWithGuard = false;
+    bool loopSideEffectFreeOrHasOnlyReadSideEffect =
+        isMemoryEffectFreeOrOnlyRead(region->getParentOp());
 
     std::queue<Operation *> worklist;
     // Add top-level operations in the loop body to the worklist.
-    for (Operation &op : loopRegion->getOps())
+    for (Operation &op : region->getOps())
       worklist.push(&op);
 
     auto definedOutside = [&](Value value) {
-      return isDefinedOutsideRegion(value, loopRegion);
-    };
-
-    auto definedOutsideGuard = [&](Value value) {
-      return isDefinedOutsideRegion(value, loopRegion->getParentRegion());
+      return isDefinedOutsideRegion(value, region);
     };
 
     while (!worklist.empty()) {
       Operation *op = worklist.front();
       worklist.pop();
       // Skip ops that have already been moved. Check if the op can be hoisted.
-      if (op->getParentRegion() != loopRegion)
+      if (op->getParentRegion() != region)
         continue;
 
       LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
 
-      if (!shouldMoveOutOfRegion(op, loopRegion) ||
+      if (!shouldMoveOutOfRegion(op, region) ||
           !canBeHoisted(op, definedOutside))
         continue;
       // Can only hoist pure ops (side-effect free) when there is an op with
@@ -143,30 +98,25 @@ size_t mlir::moveLoopInvariantCode(
       if (!loopSideEffectFreeOrHasOnlyReadSideEffect && !isMemoryEffectFree(op))
         continue;
 
-      LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
-
-      auto moveWithoutGuard = isMemoryEffectFree(op) &&
-                              !dependsOnGuarded(op, definedOutsideGuard) &&
-                              isLoopWrapped;
-      numMovedWithoutGuard += moveWithoutGuard;
-
-      moveOutOfRegion(op, moveWithoutGuard ? loopRegion->getParentRegion()
-                                           : loopRegion);
+      bool moveWithoutGuard = !anyOpHoistedWithGuard && isMemoryEffectFree(op);
+      if (moveWithoutGuard) {
+        LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op
+                                << " without guard\n");
+        moveOutOfRegionWithoutGuard(op);
+      } else {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "Moving loop-invariant op: " << *op << " with guard\n");
+        moveOutOfRegionWithGuard(op);
+        anyOpHoistedWithGuard = true;
+      }
       ++numMoved;
 
       // Since the op has been moved, we need to check its users within the
       // top-level of the loop body.
       for (Operation *user : op->getUsers())
-        if (user->getParentRegion() == loopRegion)
+        if (user->getParentRegion() == region)
           worklist.push(user);
     }
-
-    // Unwrap the loop if it was wrapped but no ops were moved in the guard.
-    if (isLoopWrapped && numMovedWithoutGuard == numMoved) {
-      auto tripCountCheckUnwrapped = unwrapGuard();
-      if (failed(tripCountCheckUnwrapped))
-        llvm_unreachable("Should not fail unwrapping trip-count check");
-    }
   }
 
   return numMoved;
@@ -179,14 +129,10 @@ size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
         return !region->isAncestor(value.getParentRegion());
       },
       [&](Operation *op, Region *) {
-        return isSpeculatable(op) &&
-               (isMemoryEffectFree(op) || hasOnlyReadEffect(op));
-      },
-      [&]() { return loopLike.wrapInTripCountCheck(); },
-      [&](Operation *op, Region *region) {
-        op->moveBefore(region->getParentOp());
+        return isSpeculatable(op) && isMemoryEffectFreeOrOnlyRead(op);
       },
-      [&]() { return loopLike.unwrapTripCountCheck(); });
+      [&](Operation *op) { loopLike.moveOutOfLoop(op); },
+      [&](Operation *op) { loopLike.moveOutOfLoopWithGuard(op); });
 }
 
 namespace {
diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir
index f2c4f09e1e7b22..6a1ca519dec1dd 100644
--- a/mlir/test/Transforms/loop-invariant-code-motion.mlir
+++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir
@@ -595,48 +595,93 @@ func.func @test_recursively_speculatable_op_failure(%lb: index, %ub: index, %ste
 
 // CHECK-LABEL: test_speculatable_op_with_read_side_effect_success
 func.func @test_speculatable_op_with_read_side_effect_success(%lb: index, %ub: index, %step: index) -> i32 {
-  // CHECK: arith.cmpi
-  // CHECK-NEXT: test.always_speculatable_op
-  // CHECK-NEXT: scf.if
+  // CHECK: test.always_speculatable_op
+  // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LB:.*]], %[[UB:.*]] : index
+  // CHECK-NEXT: scf.if %[[CMP]]
   // CHECK-NEXT: test.speculatable_op_with_memread
-  // CHECK-NEXT: scf.for
+  // CHECK: else
+  // CHECK-NEXT: ub.poison : i32
+  // CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]]
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK-NOT: test.speculatable_op_with_memread
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always_speculate = "test.always_speculatable_op"() : () -> i32
+    %only_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i: index to i32
+    %add = arith.addi %acc, %i_cast : i32
+    %sum = arith.addi %add, %only_read : i32
+    scf.yield %sum : i32
+  }
+  return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_multiple_result_success
+func.func @test_speculatable_op_with_read_side_effect_multiple_result_success(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: test.always_speculatable_op
+  // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LB:.*]], %[[UB:.*]] : index
+  // CHECK-NEXT: scf.if %[[CMP]]
+  // CHECK-NEXT: test.speculatable_op_with_memread
+  // CHECK: else
+  // CHECK-NEXT: ub.poison : i32
+  // CHECK-NEXT: ub.poison : f32
+  // CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]]
   // CHECK-NOT: test.always_speculatable_op
   // CHECK-NOT: test.speculatable_op_with_memread
   %cst_0 = arith.constant 0 : i32
   %cst_42 = arith.constant dense<42> : tensor<64xi32>
   %ind_42 = arith.constant 42 : index
   %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
-    %always = "test.always_speculatable_op"() : () -> i32
-    %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %always_speculate = "test.always_speculatable_op"() : () -> i32
+    %only_read:2 = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> (i32, f32)
     %i_cast = arith.index_cast %i: index to i32
-    %i_sum = arith.addi %acc, %i_cast : i32
-    %test_sum = arith.addi %i_sum, %always_read : i32
-    scf.yield %test_sum : i32
+    %add = arith.addi %acc, %i_cast : i32
+    %sum = arith.addi %add, %only_read#0 : i32
+    scf.yield %sum : i32
   }
   return %sum_result : i32
 }
 
 // CHECK-LABEL: test_speculatable_op_with_read_side_effect_success_with_dependents
 func.func @test_speculatable_op_with_read_side_effect_success_with_dependents(%lb: index, %ub: index, %step: index) -> i32 {
-  // CHECK: arith.cmpi
-  // CHECK-NEXT: test.always_speculatable_op
-  // CHECK-NEXT: scf.if
+  // CHECK: %[[ALWAYS:.*]] = "test.always_speculatable_op"
+  // CHECK-NEXT: %[[CMP0:.*]] = arith.cmpi ult, %[[LB:.*]], %[[UB:.*]] : index
+  // CHECK-NEXT: %[[IF0:.*]] = scf.if %[[CMP0]]
   // CHECK-NEXT: test.speculatable_op_with_memread
-  // CHECK-NEXT: arith.addi
-  // CHECK-NEXT: scf.for
+  // CHECK: else
+  // CHECK-NEXT: ub.poison : i32
+  // CHECK: %[[CMP1:.*]] = arith.cmpi ult, %[[LB]], %[[UB]] : index
+  // CHECK-NEXT: %[[IF1:.*]] = scf.if %[[CMP1]]
+  // CHECK-NEXT: arith.addi %[[ALWAYS]], %[[IF0]]
+  // CHECK: else
+  // CHECK-NEXT: ub.poison : i32
+  // CHECK: %[[CMP2:.*]] = arith.cmpi ult, %[[LB]], %[[UB]] : index
+  // CHECK-NEXT: %[[IF2:.*]] = scf.if %[[CMP2]]
+  // CHECK-NEXT: test.speculatable_op_with_memread
+  // CHECK: else
+  // CHECK-NEXT: ub.poison : i32
+  // CHECK: %[[CMP3:.*]] = arith.cmpi ult, %[[LB]], %[[UB]] : index
+  // CHECK-NEXT: %{{.*}} = scf.if %[[CMP3]]
+  // CHECK-NEXT: arith.addi %[[IF1]], %[[IF2]]
+  // CHECK: else
+  // CHECK-NEXT: ub.poison : i32
+  // CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]]
   // CHECK-NOT: test.always_speculatable_op
   // CHECK-NOT: test.speculatable_op_with_memread
   %cst_0 = arith.constant 0 : i32
   %cst_42 = arith.constant dense<42> : tensor<64xi32>
   %ind_42 = arith.constant 42 : index
   %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
-    %always = "test.always_speculatable_op"() : () -> i32
-    %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
-    %add = arith.addi %always_read, %cst_0 : i32
+    %always_speculate = "test.always_speculatable_op"() : () -> i32
+    %only_read_0 = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %add_0 = arith.addi %always_speculate, %only_read_0 : i32
+    %only_read_1 = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %add_1 = arith.addi %add_0, %only_read_1 : i32
     %i_cast = arith.index_cast %i: index to i32
-    %i_sum = arith.addi %acc, %i_cast : i32
-    %test_sum = arith.addi %i_sum, %add : i32
-    scf.yield %test_sum : i32
+    %sum = arith.addi %add_1, %i_cast : i32
+    scf.yield %sum : i32
   }
   return %sum_result : i32
 }
@@ -652,13 +697,13 @@ func.func @test_speculatable_op_with_read_side_effect_failure_due_to_write(%lb:
   %cst_42 = arith.constant dense<42> : tensor<64xi32>
   %ind_42 = arith.constant 42 : index
   %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
-    %always = "test.always_speculatable_op"() : () -> i32
-    %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %always_speculate = "test.always_speculatable_op"() : () -> i32
+    %only_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
     %i_cast = arith.index_cast %i: index to i32
-    %i_sum = arith.addi %acc, %i_cast : i32
-    %test_sum = arith.addi %i_sum, %always_read : i32
-    %always_write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
-    scf.yield %test_sum : i32
+    %add = arith.addi %acc, %i_cast : i32
+    %sum = arith.addi %add, %only_read : i32
+    %write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
+    scf.yield %sum : i32
   }
   return %sum_result : i32
 }
@@ -676,18 +721,18 @@ func.func @test_speculatable_op_with_read_side_effect_failure_due_to_nested_writ
   %cst_42 = arith.constant dense<42> : tensor<64xi32>
   %ind_42 = arith.constant 42 : index
   %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
-    %always = "test.always_speculatable_op"() : () -> i32
-    %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %always_speculate = "test.always_speculatable_op"() : () -> i32
+    %only_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
     %i_cast = arith.index_cast %i: index to i32
-    %i_sum = arith.addi %acc, %i_cast : i32
-    %test_sum = arith.addi %i_sum, %always_read : i32
+    %add = arith.addi %acc, %i_cast : i32
+    %sum = arith.addi %add, %only_read : i32
     scf.for %j = %lb to %ub step %step {
       %eq42 = arith.cmpi eq, %j, %ind_42 : index
       scf.if %eq42 {
         %always_write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
       }
     }
-    scf.yield %test_sum : i32
+    scf.yield %sum : i32
   }
   return %sum_result : i32
 }



More information about the Mlir-commits mailing list