[Mlir-commits] [mlir] Enable LICM for ops with only read side effects in scf.for (PR #120302)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 17 12:55:01 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-scf

Author: Arda Unal (ardaunal)

<details>
<summary>Changes</summary>

Enable ops with only read side effects in scf.for to be hoisted into an scf.if guard that checks the trip count.

This patch takes a step towards a less conservative LICM in MLIR, as discussed in the following Discourse thread:

[Speculative LICM?](https://discourse.llvm.org/t/speculative-licm/80977)

This patch in particular does the following:

1. Relaxes the original hoisting constraint, which only hoisted ops without any side effects. Ops with only read side effects can now also be hoisted, into an scf.if guard, but only if every op in the loop or its nested regions is side-effect free or has only read side effects. This scf.if guard wraps the original scf.for and checks for **trip_count > 0**.
2. To support this, two new interface methods are added to **LoopLikeInterface**: _wrapInTripCountCheck_ and _unwrapTripCountCheck_. The implementation starts by wrapping the scf.for loop in an scf.if guard using _wrapInTripCountCheck_; if no op has been hoisted into this guard once the worklist has been processed, it removes the guard by calling _unwrapTripCountCheck_.

---

Patch is 22.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120302.diff


9 Files Affected:

- (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (+2-1) 
- (modified) mlir/include/mlir/Interfaces/LoopLikeInterface.td (+20) 
- (modified) mlir/include/mlir/Interfaces/SideEffectInterfaces.h (+4) 
- (modified) mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h (+9-3) 
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+79-3) 
- (modified) mlir/lib/Interfaces/SideEffectInterfaces.cpp (+7) 
- (modified) mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp (+85-11) 
- (modified) mlir/test/Transforms/loop-invariant-code-motion.mlir (+105) 
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+44) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 23c597a1ca5108..b54df8e3ef313d 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -139,6 +139,7 @@ def ForOp : SCF_Op<"for",
         "getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
         "getLoopUpperBounds", "getYieldedValuesMutable",
         "promoteIfSingleIteration", "replaceWithAdditionalYields",
+        "wrapInTripCountCheck", "unwrapTripCountCheck",
         "yieldTiledValuesAndReplace"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
        ConditionallySpeculatable,
@@ -302,7 +303,7 @@ def ForallOp : SCF_Op<"forall", [
        AttrSizedOperandSegments,
        AutomaticAllocationScope,
        DeclareOpInterfaceMethods<LoopLikeOpInterface,
-          ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars", 
+          ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
            "getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
            "promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
        RecursiveMemoryEffects,
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index c6bffe347419e5..831830130b0ddc 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -79,6 +79,26 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       /*methodBody=*/"",
       /*defaultImplementation=*/"op->moveBefore($_op);"
     >,
+    InterfaceMethod<[{
+        Wraps the loop into a trip-count check.
+      }],
+      /*retTy=*/"FailureOr<std::pair<::mlir::Operation *, ::mlir::Region *>>",
+      /*methodName=*/"wrapInTripCountCheck",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/"return ::mlir::failure();"
+    >,
+    InterfaceMethod<[{
+        Unwraps the trip-count check.
+      }],
+      /*retTy=*/"::llvm::LogicalResult",
+      /*methodName=*/"unwrapTripCountCheck",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::failure();
+      }]
+    >,
     InterfaceMethod<[{
         Promotes the loop body to its containing block if the loop is known to
         have a single iteration. Returns "success" if the promotion was
diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
index aef7ec622fe4f8..1a7f66e2234949 100644
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
@@ -433,6 +433,10 @@ bool wouldOpBeTriviallyDead(Operation *op);
 /// conditions are satisfied.
 bool isMemoryEffectFree(Operation *op);
 
+/// Returns true if the given operation implements `MemoryEffectOpInterface` and
+/// has only read effects.
+bool hasOnlyReadEffect(Operation *op);
+
 /// Returns the side effects of an operation. If the operation has
 /// RecursiveMemoryEffects, include all side effects of child operations.
 ///
diff --git a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
index 3ceef44d799e89..ae6719abe79c00 100644
--- a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
@@ -48,15 +48,19 @@ class Value;
 /// }
 /// ```
 ///
-/// Users must supply three callbacks.
+/// Users must supply five callbacks.
 ///
 /// - `isDefinedOutsideRegion` returns true if the given value is invariant with
 ///   respect to the given region. A common implementation might be:
 ///   `value.getParentRegion()->isProperAncestor(region)`.
 /// - `shouldMoveOutOfRegion` returns true if the provided operation can be
-///   moved of the given region, e.g. if it is side-effect free.
+///   moved of the given region, e.g. if it is side-effect free or has only read
+///   side effects.
+/// - `wrapInGuard` wraps the given operation in a trip-count check guard.
 /// - `moveOutOfRegion` moves the operation out of the given region. A common
 ///   implementation might be: `op->moveBefore(region->getParentOp())`.
+/// - `unwrapGuard` unwraps the trip-count check if there is no op guarded by
+///   this check.
 ///
 /// An operation is moved if all of its operands satisfy
 /// `isDefinedOutsideRegion` and it satisfies `shouldMoveOutOfRegion`.
@@ -66,7 +70,9 @@ size_t moveLoopInvariantCode(
     ArrayRef<Region *> regions,
     function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
     function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
-    function_ref<void(Operation *, Region *)> moveOutOfRegion);
+    function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
+    function_ref<void(Operation *, Region *)> moveOutOfRegion,
+    function_ref<LogicalResult()> unwrapGuard);
 
 /// Move side-effect free loop invariant code out of a loop-like op using
 /// methods provided by the interface.
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index eded1c394f126c..148617c84547c7 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -395,6 +395,83 @@ std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
 
 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
 
+FailureOr<std::pair<Operation *, Region *>> ForOp::wrapInTripCountCheck() {
+  auto lowerBound = this->getLowerBound();
+  auto upperBound = this->getUpperBound();
+  auto step = this->getStep();
+  auto initArgs = this->getInitArgs();
+  auto results = this->getResults();
+  auto loc = this->getLoc();
+
+  IRRewriter rewriter(this->getContext());
+  OpBuilder::InsertionGuard insertGuard(rewriter);
+  rewriter.setInsertionPointAfter(this->getOperation());
+
+  // Form the trip count calculation
+  auto subOp = rewriter.create<arith::SubIOp>(loc, upperBound, lowerBound);
+  auto ceilDivSIOp = rewriter.create<arith::CeilDivSIOp>(loc, subOp, step);
+  Value zero;
+  if (upperBound.getType().isIndex()) {
+    zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  } else {
+    zero = rewriter.create<arith::ConstantIntOp>(
+        loc, 0,
+        /*width=*/
+        upperBound.getType().getIntOrFloatBitWidth());
+  }
+  auto cmpIOp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
+                                               ceilDivSIOp, zero);
+  scf::YieldOp yieldInThen;
+  // Create the trip-count check
+  auto ifOp = rewriter.create<scf::IfOp>(
+      loc, cmpIOp,
+      [&](OpBuilder &builder, Location loc) {
+        yieldInThen = builder.create<scf::YieldOp>(loc, results);
+      },
+      [&](OpBuilder &builder, Location loc) {
+        builder.create<scf::YieldOp>(loc, initArgs);
+      });
+
+  for (auto [forOpResult, ifOpResult] : llvm::zip(results, ifOp.getResults()))
+    rewriter.replaceAllUsesExcept(forOpResult, ifOpResult, yieldInThen);
+  // Move the scf.for into the then block
+  rewriter.moveOpBefore(this->getOperation(), yieldInThen);
+  return std::make_pair(ifOp.getOperation(), &this->getRegion());
+}
+
+LogicalResult ForOp::unwrapTripCountCheck() {
+  auto ifOp = (*this)->getParentRegion()->getParentOp();
+  if (!isa<scf::IfOp>(ifOp))
+    return failure();
+
+  auto wrappedForOp = this->getOperation();
+
+  IRRewriter rewriter(ifOp->getContext());
+  OpBuilder::InsertionGuard insertGuard(rewriter);
+  rewriter.setInsertionPoint(ifOp);
+
+  auto cmpOp = ifOp->getOperand(0).getDefiningOp();
+  auto ceilDivSIOp = cmpOp->getOperand(0).getDefiningOp();
+  auto zero = cmpOp->getOperand(1).getDefiningOp();
+  auto subOp = ceilDivSIOp->getOperand(0).getDefiningOp();
+  if (!isa<arith::CmpIOp>(cmpOp) || !isa<arith::CeilDivSIOp>(ceilDivSIOp) ||
+      !isa<arith::SubIOp>(subOp))
+    return failure();
+
+  rewriter.moveOpBefore(wrappedForOp, ifOp);
+
+  for (auto [forOpResult, ifOpResult] :
+       llvm::zip(wrappedForOp->getResults(), ifOp->getResults()))
+    rewriter.replaceAllUsesWith(ifOpResult, forOpResult);
+
+  rewriter.eraseOp(ifOp);
+  rewriter.eraseOp(cmpOp);
+  rewriter.eraseOp(zero);
+  rewriter.eraseOp(ceilDivSIOp);
+  rewriter.eraseOp(subOp);
+  return success();
+}
+
 /// Promotes the loop body of a forOp to its containing block if the forOp
 /// it can be determined that the loop has a single iteration.
 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
@@ -3397,9 +3474,8 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
 
   if (functionType.getNumInputs() != operands.size()) {
     return parser.emitError(typeLoc)
-           << "expected as many input types as operands "
-           << "(expected " << operands.size() << " got "
-           << functionType.getNumInputs() << ")";
+           << "expected as many input types as operands " << "(expected "
+           << operands.size() << " got " << functionType.getNumInputs() << ")";
   }
 
   // Resolve input operands.
diff --git a/mlir/lib/Interfaces/SideEffectInterfaces.cpp b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
index c9feb001a19844..f45d5f3d227407 100644
--- a/mlir/lib/Interfaces/SideEffectInterfaces.cpp
+++ b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
@@ -306,6 +306,13 @@ bool mlir::wouldOpBeTriviallyDead(Operation *op) {
   return wouldOpBeTriviallyDeadImpl(op);
 }
 
+bool mlir::hasOnlyReadEffect(Operation *op) {
+  if (auto memEffects = dyn_cast<MemoryEffectOpInterface>(op)) {
+    return memEffects.onlyHasEffect<MemoryEffects::Read>();
+  }
+  return false;
+}
+
 bool mlir::isMemoryEffectFree(Operation *op) {
   if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
     if (!memInterface.hasNoEffect())
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 7460746934a78c..1bdc74dc2a170a 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -56,48 +56,117 @@ static bool canBeHoisted(Operation *op,
       op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
 }
 
+static bool dependsOnGuarded(Operation *op,
+                             function_ref<bool(OpOperand &)> condition) {
+  auto walkFn = [&](Operation *child) {
+    for (OpOperand &operand : child->getOpOperands()) {
+      if (!condition(operand))
+        return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  };
+  return op->walk(walkFn).wasInterrupted();
+}
+
+static bool dependsOnGuarded(Operation *op,
+                             function_ref<bool(Value)> definedOutsideGuard) {
+  return dependsOnGuarded(op, [&](OpOperand &operand) {
+    return definedOutsideGuard(operand.get());
+  });
+}
+
+static bool loopSideEffectFreeOrHasOnlyReadEffect(Operation *loop) {
+  for (auto &region : loop->getRegions()) {
+    for (auto &block : region.getBlocks()) {
+      for (Operation &op : block.getOperations()) {
+        if (!isMemoryEffectFree(&op) && !hasOnlyReadEffect(&op))
+          return false;
+      }
+    }
+  }
+  return true;
+}
+
 size_t mlir::moveLoopInvariantCode(
     ArrayRef<Region *> regions,
     function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
     function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
-    function_ref<void(Operation *, Region *)> moveOutOfRegion) {
+    function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
+    function_ref<void(Operation *, Region *)> moveOutOfRegion,
+    function_ref<LogicalResult()> unwrapGuard) {
   size_t numMoved = 0;
 
   for (Region *region : regions) {
     LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
                             << *region->getParentOp() << "\n");
 
+    auto loopSideEffectFreeOrHasOnlyReadSideEffect =
+        loopSideEffectFreeOrHasOnlyReadEffect(region->getParentOp());
+
+    size_t numMovedWithoutGuard = 0;
+
+    FailureOr<std::pair<Operation *, Region *>> ifOpAndRegion = wrapInGuard();
+    Region *loopRegion = region;
+    auto isLoopWrapped = false;
+    if (succeeded(ifOpAndRegion)) {
+      loopRegion = ifOpAndRegion->second;
+      isLoopWrapped = true;
+    }
+
     std::queue<Operation *> worklist;
     // Add top-level operations in the loop body to the worklist.
-    for (Operation &op : region->getOps())
+    for (Operation &op : loopRegion->getOps())
       worklist.push(&op);
 
     auto definedOutside = [&](Value value) {
-      return isDefinedOutsideRegion(value, region);
+      return isDefinedOutsideRegion(value, loopRegion);
+    };
+
+    auto definedOutsideGuard = [&](Value value) {
+      return isDefinedOutsideRegion(value, loopRegion->getParentRegion());
     };
 
     while (!worklist.empty()) {
       Operation *op = worklist.front();
       worklist.pop();
       // Skip ops that have already been moved. Check if the op can be hoisted.
-      if (op->getParentRegion() != region)
+      if (op->getParentRegion() != loopRegion)
         continue;
 
       LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
-      if (!shouldMoveOutOfRegion(op, region) ||
+
+      if (!shouldMoveOutOfRegion(op, loopRegion) ||
           !canBeHoisted(op, definedOutside))
         continue;
+      // Can only hoist pure ops (side-effect free) when there is an op with
+      // write side effects in the loop
+      if (!loopSideEffectFreeOrHasOnlyReadSideEffect && !isMemoryEffectFree(op))
+        continue;
 
       LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
-      moveOutOfRegion(op, region);
+
+      auto moveWithoutGuard = isMemoryEffectFree(op) &&
+                              !dependsOnGuarded(op, definedOutsideGuard) &&
+                              isLoopWrapped;
+      numMovedWithoutGuard += moveWithoutGuard;
+
+      moveOutOfRegion(op, moveWithoutGuard ? loopRegion->getParentRegion()
+                                           : loopRegion);
       ++numMoved;
 
       // Since the op has been moved, we need to check its users within the
       // top-level of the loop body.
       for (Operation *user : op->getUsers())
-        if (user->getParentRegion() == region)
+        if (user->getParentRegion() == loopRegion)
           worklist.push(user);
     }
+
+    // Unwrap the loop if it was wrapped but no ops were moved in the guard.
+    if (isLoopWrapped && numMovedWithoutGuard == numMoved) {
+      auto tripCountCheckUnwrapped = unwrapGuard();
+      if (failed(tripCountCheckUnwrapped))
+        llvm_unreachable("Should not fail unwrapping trip-count check");
+    }
   }
 
   return numMoved;
@@ -106,13 +175,18 @@ size_t mlir::moveLoopInvariantCode(
 size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
   return moveLoopInvariantCode(
       loopLike.getLoopRegions(),
-      [&](Value value, Region *) {
-        return loopLike.isDefinedOutsideOfLoop(value);
+      [&](Value value, Region *region) {
+        return !region->isAncestor(value.getParentRegion());
       },
       [&](Operation *op, Region *) {
-        return isMemoryEffectFree(op) && isSpeculatable(op);
+        return isSpeculatable(op) &&
+               (isMemoryEffectFree(op) || hasOnlyReadEffect(op));
+      },
+      [&]() { return loopLike.wrapInTripCountCheck(); },
+      [&](Operation *op, Region *region) {
+        op->moveBefore(region->getParentOp());
       },
-      [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
+      [&]() { return loopLike.unwrapTripCountCheck(); });
 }
 
 namespace {
diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir
index e4c423ce7052bf..6f5cc60c59252c 100644
--- a/mlir/test/Transforms/loop-invariant-code-motion.mlir
+++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir
@@ -593,6 +593,111 @@ func.func @test_recursively_speculatable_op_failure(%lb: index, %ub: index, %ste
   return
 }
 
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success
+func.func @test_speculatable_op_with_read_side_effect_success(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: arith.subi
+  // CHECK: arith.ceildivsi
+  // CHECK: arith.constant
+  // CHECK: arith.cmpi
+  // CHECK-NEXT: test.always_speculatable_op
+  // CHECK-NEXT: scf.if
+  // CHECK-NEXT: test.speculatable_op_with_memread
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK-NOT: test.speculatable_op_with_memread
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always = "test.always_speculatable_op"() : () -> i32
+    %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i: index to i32
+    %i_sum = arith.addi %acc, %i_cast : i32
+    %test_sum = arith.addi %i_sum, %always_read : i32
+    scf.yield %test_sum : i32
+  }
+  return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success_with_dependents
+func.func @test_speculatable_op_with_read_side_effect_success_with_dependents(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: arith.subi
+  // CHECK: arith.ceildivsi
+  // CHECK: arith.constant
+  // CHECK: arith.cmpi
+  // CHECK-NEXT: test.always_speculatable_op
+  // CHECK-NEXT: scf.if
+  // CHECK-NEXT: test.speculatable_op_with_memread
+  // CHECK-NEXT: arith.addi
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK-NOT: test.speculatable_op_with_memread
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always = "test.always_speculatable_op"() : () -> i32
+    %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %add = arith.addi %always_read, %cst_0 : i32
+    %i_cast = arith.index_cast %i: index to i32
+    %i_sum = arith.addi %acc, %i_cast : i32
+    %test_sum = arith.addi %i_sum, %add : i32
+    scf.yield %test_sum : i32
+  }
+  return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_write
+func.func @test_speculatable_op_with_read_side_effect_failure_due_to_write(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: test.always_speculatable_op
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK: test.speculatable_op_with_memread
+  // CHECK: test.speculatable_op_with_memwrite
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always = "test.always_speculatable_op"() : () -> i32
+    %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i: index to i32
+    %i_sum = arith.addi %acc, %i_cast : i32
+    %test_sum = arith.addi %i_sum, %always_read : i32
+    %always_write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
+    scf.yield %test_sum : i32
+  }
+  return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_nested_write
+func.func @test_speculatable_op_with_read_side_effect_failure_due_to_nested_write(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: test.always_speculatable_op
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK: test.speculatable_op_with_memread
+  // CHECK: scf.for
+  // CHECK: scf.if
+  // CHECK: test.speculatable_op_with_memwrite
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always = "test.always_speculatable_op"() : () -> i32
+    %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i: index to i32
+    %i_sum = arith....
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/120302


More information about the Mlir-commits mailing list