[Mlir-commits] [mlir] 36d8e70 - [mlir] Enable LICM for ops with only read side effects in scf.for (#120302)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 11 15:49:00 PST 2025


Author: Arda Unal
Date: 2025-02-11T15:48:57-08:00
New Revision: 36d8e7056ef08d0f51e5b1e928274a8ca7dadeb5

URL: https://github.com/llvm/llvm-project/commit/36d8e7056ef08d0f51e5b1e928274a8ca7dadeb5
DIFF: https://github.com/llvm/llvm-project/commit/36d8e7056ef08d0f51e5b1e928274a8ca7dadeb5.diff

LOG: [mlir] Enable LICM for ops with only read side effects in scf.for (#120302)

Enable ops with only read side effects in scf.for to be hoisted with an
scf.if guard that checks against the trip count.

This patch takes a step towards a less conservative LICM in MLIR as
discussed in the following discourse thread:

[Speculative LICM?](https://discourse.llvm.org/t/speculative-licm/80977)

This patch in particular does the following:

1. Relaxes the original constraint for hoisting that only hoists ops
without any side effects. This patch also allows the ops with only read
side effects to be hoisted into an scf.if guard only if every op in the
loop or its nested regions is side-effect free or has only read side
effects. This scf.if guard wraps the original scf.for and checks for
**trip_count > 0**.
2. To support this, two new interface methods are added to
**LoopLikeInterface**: _wrapInTripCountCheck_ and
_unwrapTripCountCheck_. Implementation starts with wrapping the scf.for
loop into an scf.if guard using _wrapInTripCountCheck_ and, if there is
no op hoisted into this guard after we are done processing the
worklist, it unwraps the guard by calling _unwrapTripCountCheck_.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
    mlir/include/mlir/Interfaces/LoopLikeInterface.td
    mlir/include/mlir/Interfaces/SideEffectInterfaces.h
    mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Interfaces/SideEffectInterfaces.cpp
    mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
    mlir/test/Transforms/loop-invariant-code-motion.mlir
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 6f408b3c924de..ac648aca087f3 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -40,7 +40,7 @@ def SCF_Dialect : Dialect {
     and then lowered to some final target like LLVM or SPIR-V.
   }];
 
-  let dependentDialects = ["arith::ArithDialect"];
+  let dependentDialects = ["arith::ArithDialect", "ub::UBDialect"];
 }
 
 // Base class for SCF dialect ops.
@@ -138,7 +138,9 @@ def ForOp : SCF_Op<"for",
        ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
         "getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
         "getLoopUpperBounds", "getYieldedValuesMutable",
+        "moveOutOfLoopWithGuard",
         "promoteIfSingleIteration", "replaceWithAdditionalYields",
+        "wrapInTripCountCheck", "unwrapTripCountCheck",
         "yieldTiledValuesAndReplace"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
        ConditionallySpeculatable,

diff  --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index c6bffe347419e..c4ff0b485565e 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -79,6 +79,18 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       /*methodBody=*/"",
       /*defaultImplementation=*/"op->moveBefore($_op);"
     >,
+    InterfaceMethod<[{
+        Moves the given loop-invariant operation out of the loop with a
+        trip-count guard.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"moveOutOfLoopWithGuard",
+      /*args=*/(ins "::mlir::Operation *":$op),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return;
+      }]
+    >,
     InterfaceMethod<[{
         Promotes the loop body to its containing block if the loop is known to
         have a single iteration. Returns "success" if the promotion was

diff  --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
index aef7ec622fe4f..f8142567c6060 100644
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
@@ -433,6 +433,10 @@ bool wouldOpBeTriviallyDead(Operation *op);
 /// conditions are satisfied.
 bool isMemoryEffectFree(Operation *op);
 
+/// Returns true if the given operation is free of memory effects or has only
+/// read effect.
+bool isMemoryEffectFreeOrOnlyRead(Operation *op);
+
 /// Returns the side effects of an operation. If the operation has
 /// RecursiveMemoryEffects, include all side effects of child operations.
 ///

diff  --git a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
index 3ceef44d799e8..83a70abe0a5ed 100644
--- a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
@@ -47,16 +47,19 @@ class Value;
 ///   }
 /// }
 /// ```
-///
-/// Users must supply three callbacks.
+/// Users must supply four callbacks.
 ///
 /// - `isDefinedOutsideRegion` returns true if the given value is invariant with
 ///   respect to the given region. A common implementation might be:
 ///   `value.getParentRegion()->isProperAncestor(region)`.
 /// - `shouldMoveOutOfRegion` returns true if the provided operation can be
-///   moved of the given region, e.g. if it is side-effect free.
-/// - `moveOutOfRegion` moves the operation out of the given region. A common
-///   implementation might be: `op->moveBefore(region->getParentOp())`.
+///   moved out of the given region, e.g. if it is side-effect free or has only
+///   read side effects.
+/// - `moveOutOfRegionWithoutGuard` moves the operation out of the given region
+///   without a guard. A common implementation might be:
+///   `op->moveBefore(region->getParentOp())`.
+/// - `moveOutOfRegionWithGuard` moves the operation out of the given region
+///   with a guard.
 ///
 /// An operation is moved if all of its operands satisfy
 /// `isDefinedOutsideRegion` and it satisfies `shouldMoveOutOfRegion`.
@@ -66,7 +69,8 @@ size_t moveLoopInvariantCode(
     ArrayRef<Region *> regions,
     function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
     function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
-    function_ref<void(Operation *, Region *)> moveOutOfRegion);
+    function_ref<void(Operation *)> moveOutOfRegionWithoutGuard,
+    function_ref<void(Operation *)> moveOutOfRegionWithGuard);
 
 /// Move side-effect free loop invariant code out of a loop-like op using
 /// methods provided by the interface.

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 83ae79ce48266..85729486f63d5 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
@@ -395,6 +396,40 @@ std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
 
 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
 
+/// Moves the op out of the loop with a guard that checks if the loop has at
+/// least one iteration.
+void ForOp::moveOutOfLoopWithGuard(Operation *op) {
+  IRRewriter rewriter(this->getContext());
+  OpBuilder::InsertionGuard insertGuard(rewriter);
+  rewriter.setInsertionPoint(this->getOperation());
+  Location loc = this->getLoc();
+  arith::CmpIOp cmpIOp = rewriter.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::ult, this->getLowerBound(),
+      this->getUpperBound());
+  // Create the trip-count check.
+  scf::YieldOp thenYield;
+  scf::IfOp ifOp = rewriter.create<scf::IfOp>(
+      loc, cmpIOp,
+      [&](OpBuilder &builder, Location loc) {
+        thenYield = builder.create<scf::YieldOp>(loc, op->getResults());
+      },
+      [&](OpBuilder &builder, Location loc) {
+        SmallVector<Value> poisonResults;
+        poisonResults.reserve(op->getResults().size());
+        for (Type type : op->getResults().getTypes()) {
+          ub::PoisonOp poisonOp =
+              rewriter.create<ub::PoisonOp>(loc, type, nullptr);
+          poisonResults.push_back(poisonOp);
+        }
+        builder.create<scf::YieldOp>(loc, poisonResults);
+      });
+  for (auto [opResult, ifOpResult] :
+       llvm::zip_equal(op->getResults(), ifOp->getResults()))
+    rewriter.replaceAllUsesExcept(opResult, ifOpResult, thenYield);
+  // Move the op into the then block.
+  rewriter.moveOpBefore(op, thenYield);
+}
+
 /// Promotes the loop body of a forOp to its containing block if the forOp
 /// it can be determined that the loop has a single iteration.
 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
@@ -3394,9 +3429,8 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
 
   if (functionType.getNumInputs() != operands.size()) {
     return parser.emitError(typeLoc)
-           << "expected as many input types as operands "
-           << "(expected " << operands.size() << " got "
-           << functionType.getNumInputs() << ")";
+           << "expected as many input types as operands " << "(expected "
+           << operands.size() << " got " << functionType.getNumInputs() << ")";
   }
 
   // Resolve input operands.

diff  --git a/mlir/lib/Interfaces/SideEffectInterfaces.cpp b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
index 59fd19310cea5..7418bc96815ec 100644
--- a/mlir/lib/Interfaces/SideEffectInterfaces.cpp
+++ b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "mlir/IR/SymbolTable.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include <utility>
 
@@ -370,6 +371,17 @@ mlir::getEffectsRecursively(Operation *rootOp) {
   return effects;
 }
 
+bool mlir::isMemoryEffectFreeOrOnlyRead(Operation *op) {
+  std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
+      getEffectsRecursively(op);
+  if (!effects)
+    return false;
+  return llvm::all_of(*effects,
+                      [&](const MemoryEffects::EffectInstance &effect) {
+                        return isa<MemoryEffects::Read>(effect.getEffect());
+                      });
+}
+
 bool mlir::isSpeculatable(Operation *op) {
   auto conditionallySpeculatable = dyn_cast<ConditionallySpeculatable>(op);
   if (!conditionallySpeculatable)

diff  --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 7460746934a78..094db01176a28 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -60,13 +60,18 @@ size_t mlir::moveLoopInvariantCode(
     ArrayRef<Region *> regions,
     function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
     function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
-    function_ref<void(Operation *, Region *)> moveOutOfRegion) {
+    function_ref<void(Operation *)> moveOutOfRegionWithoutGuard,
+    function_ref<void(Operation *)> moveOutOfRegionWithGuard) {
   size_t numMoved = 0;
 
   for (Region *region : regions) {
     LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
                             << *region->getParentOp() << "\n");
 
+    bool anyOpHoistedWithGuard = false;
+    bool loopSideEffectFreeOrHasOnlyReadSideEffect =
+        isMemoryEffectFreeOrOnlyRead(region->getParentOp());
+
     std::queue<Operation *> worklist;
     // Add top-level operations in the loop body to the worklist.
     for (Operation &op : region->getOps())
@@ -84,12 +89,26 @@ size_t mlir::moveLoopInvariantCode(
         continue;
 
       LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
+
       if (!shouldMoveOutOfRegion(op, region) ||
           !canBeHoisted(op, definedOutside))
         continue;
+      // Can only hoist pure ops (side-effect free) when there is an op with
+      // write and/or unknown side effects in the loop.
+      if (!loopSideEffectFreeOrHasOnlyReadSideEffect && !isMemoryEffectFree(op))
+        continue;
 
-      LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
-      moveOutOfRegion(op, region);
+      bool moveWithoutGuard = !anyOpHoistedWithGuard && isMemoryEffectFree(op);
+      if (moveWithoutGuard) {
+        LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op
+                                << " without guard\n");
+        moveOutOfRegionWithoutGuard(op);
+      } else {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "Moving loop-invariant op: " << *op << " with guard\n");
+        moveOutOfRegionWithGuard(op);
+        anyOpHoistedWithGuard = true;
+      }
       ++numMoved;
 
       // Since the op has been moved, we need to check its users within the
@@ -106,13 +125,14 @@ size_t mlir::moveLoopInvariantCode(
 size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
   return moveLoopInvariantCode(
       loopLike.getLoopRegions(),
-      [&](Value value, Region *) {
-        return loopLike.isDefinedOutsideOfLoop(value);
+      [&](Value value, Region *region) {
+        return !region->isAncestor(value.getParentRegion());
       },
       [&](Operation *op, Region *) {
-        return isMemoryEffectFree(op) && isSpeculatable(op);
+        return isSpeculatable(op) && isMemoryEffectFreeOrOnlyRead(op);
       },
-      [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
+      [&](Operation *op) { loopLike.moveOutOfLoop(op); },
+      [&](Operation *op) { loopLike.moveOutOfLoopWithGuard(op); });
 }
 
 namespace {

diff  --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir
index 5133c14414c97..66a34abca5450 100644
--- a/mlir/test/Transforms/loop-invariant-code-motion.mlir
+++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir
@@ -714,6 +714,150 @@ func.func @test_recursively_speculatable_op_failure(%lb: index, %ub: index, %ste
   return
 }
 
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success
+func.func @test_speculatable_op_with_read_side_effect_success(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: test.always_speculatable_op
+  // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LB:.*]], %[[UB:.*]] : index
+  // CHECK-NEXT: scf.if %[[CMP]]
+  // CHECK-NEXT: test.speculatable_op_with_memread
+  // CHECK: else
+  // CHECK-NEXT: ub.poison : i32
+  // CHECK: scf.for %[[_:.*]] = %[[LB]] to %[[UB]]
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK-NOT: test.speculatable_op_with_memread
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always_speculate = "test.always_speculatable_op"() : () -> i32
+    %only_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i: index to i32
+    %add = arith.addi %acc, %i_cast : i32
+    %sum = arith.addi %add, %only_read : i32
+    scf.yield %sum : i32
+  }
+  return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_multiple_result_success
+func.func @test_speculatable_op_with_read_side_effect_multiple_result_success(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: test.always_speculatable_op
+  // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LB:.*]], %[[UB:.*]] : index
+  // CHECK-NEXT: scf.if %[[CMP]]
+  // CHECK-NEXT: test.speculatable_op_with_memread
+  // CHECK: else
+  // CHECK-NEXT: ub.poison : i32
+  // CHECK-NEXT: ub.poison : f32
+  // CHECK: scf.for %[[_:.*]] = %[[LB]] to %[[UB]]
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK-NOT: test.speculatable_op_with_memread
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always_speculate = "test.always_speculatable_op"() : () -> i32
+    %only_read:2 = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> (i32, f32)
+    %i_cast = arith.index_cast %i: index to i32
+    %add = arith.addi %acc, %i_cast : i32
+    %sum = arith.addi %add, %only_read#0 : i32
+    scf.yield %sum : i32
+  }
+  return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success_with_dependents
+func.func @test_speculatable_op_with_read_side_effect_success_with_dependents(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: %[[ALWAYS:.*]] = "test.always_speculatable_op"
+  // CHECK-NEXT: %[[CMP0:.*]] = arith.cmpi ult, %[[LB:.*]], %[[UB:.*]] : index
+  // CHECK-NEXT: %[[IF0:.*]] = scf.if %[[CMP0]]
+  // CHECK-NEXT: test.speculatable_op_with_memread
+  // CHECK: else
+  // CHECK-NEXT: ub.poison : i32
+  // CHECK: %[[CMP1:.*]] = arith.cmpi ult, %[[LB]], %[[UB]] : index
+  // CHECK-NEXT: %[[IF1:.*]] = scf.if %[[CMP1]]
+  // CHECK-NEXT: arith.addi %[[ALWAYS]], %[[IF0]]
+  // CHECK: else
+  // CHECK-NEXT: ub.poison : i32
+  // CHECK: %[[CMP2:.*]] = arith.cmpi ult, %[[LB]], %[[UB]] : index
+  // CHECK-NEXT: %[[IF2:.*]] = scf.if %[[CMP2]]
+  // CHECK-NEXT: test.speculatable_op_with_memread
+  // CHECK: else
+  // CHECK-NEXT: ub.poison : i32
+  // CHECK: %[[CMP3:.*]] = arith.cmpi ult, %[[LB]], %[[UB]] : index
+  // CHECK-NEXT: %{{.*}} = scf.if %[[CMP3]]
+  // CHECK-NEXT: arith.addi %[[IF1]], %[[IF2]]
+  // CHECK: else
+  // CHECK-NEXT: ub.poison : i32
+  // CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]]
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK-NOT: test.speculatable_op_with_memread
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always_speculate = "test.always_speculatable_op"() : () -> i32
+    %only_read_0 = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %add_0 = arith.addi %always_speculate, %only_read_0 : i32
+    %only_read_1 = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %add_1 = arith.addi %add_0, %only_read_1 : i32
+    %i_cast = arith.index_cast %i: index to i32
+    %sum = arith.addi %add_1, %i_cast : i32
+    scf.yield %sum : i32
+  }
+  return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_write
+func.func @test_speculatable_op_with_read_side_effect_failure_due_to_write(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: test.always_speculatable_op
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK: test.speculatable_op_with_memread
+  // CHECK: test.speculatable_op_with_memwrite
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always_speculate = "test.always_speculatable_op"() : () -> i32
+    %only_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i: index to i32
+    %add = arith.addi %acc, %i_cast : i32
+    %sum = arith.addi %add, %only_read : i32
+    %write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
+    scf.yield %sum : i32
+  }
+  return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_nested_write
+func.func @test_speculatable_op_with_read_side_effect_failure_due_to_nested_write(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: test.always_speculatable_op
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK: test.speculatable_op_with_memread
+  // CHECK: scf.for
+  // CHECK: scf.if
+  // CHECK: test.speculatable_op_with_memwrite
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always_speculate = "test.always_speculatable_op"() : () -> i32
+    %only_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i: index to i32
+    %add = arith.addi %acc, %i_cast : i32
+    %sum = arith.addi %add, %only_read : i32
+    scf.for %j = %lb to %ub step %step {
+      %eq42 = arith.cmpi eq, %j, %ind_42 : index
+      scf.if %eq42 {
+        %always_write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
+      }
+    }
+    scf.yield %sum : i32
+  }
+  return %sum_result : i32
+}
+
 // -----
 
 func.func @speculate_tensor_dim_unknown_rank_unknown_dim(

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 2aa0658ab0e5d..193b682c034f6 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2947,6 +2947,50 @@ def RecursivelySpeculatableOp : TEST_Op<"recursively_speculatable_op", [
   let regions = (region SizedRegion<1>:$body);
 }
 
+def SpeculatableOpWithReadSideEffect : TEST_Op<"speculatable_op_with_memread",
+    [ConditionallySpeculatable, MemoryEffects<[MemRead]>]> {
+  let description = [{
+    Op used to test LICM conditional speculation.  This op can always be
+    speculatively executed and has only memory read effect.
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$inputs);
+  let results = (outs Variadic<AnyType>:$outputs);
+
+  let extraClassDeclaration = [{
+    ::mlir::Speculation::Speculatability getSpeculatability();
+  }];
+
+  let extraClassDefinition = [{
+    ::mlir::Speculation::Speculatability
+    SpeculatableOpWithReadSideEffect::getSpeculatability() {
+      return ::mlir::Speculation::Speculatable;
+    }
+  }];
+}
+
+def SpeculatableOpWithWriteSideEffect : TEST_Op<"speculatable_op_with_memwrite",
+    [ConditionallySpeculatable, MemoryEffects<[MemWrite]>]> {
+  let description = [{
+    Op used to test LICM conditional speculation.  This op can always be
+    speculatively executed and has only memory write effect.
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$inputs);
+  let results = (outs Variadic<AnyType>:$outputs);
+
+  let extraClassDeclaration = [{
+    ::mlir::Speculation::Speculatability getSpeculatability();
+  }];
+
+  let extraClassDefinition = [{
+    ::mlir::Speculation::Speculatability
+    SpeculatableOpWithWriteSideEffect::getSpeculatability() {
+      return ::mlir::Speculation::Speculatable;
+    }
+  }];
+}
+
 //===---------------------------------------------------------------------===//
 // Test CSE
 //===---------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list