[Mlir-commits] [mlir] Enable LICM for ops with only read side effects in scf.for (PR #120302)

Arda Unal llvmlistbot at llvm.org
Thu Dec 19 14:55:49 PST 2024


https://github.com/ardaunal updated https://github.com/llvm/llvm-project/pull/120302

>From 0e596fdeb855ec27a4d78918971af819c0769825 Mon Sep 17 00:00:00 2001
From: ardau <ardau at meta.com>
Date: Tue, 17 Dec 2024 12:29:21 -0800
Subject: [PATCH] Enable LICM for ops with read side effects in scf.for wrapped
 by a guard

---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    |  3 +-
 .../mlir/Interfaces/LoopLikeInterface.td      | 20 ++++
 .../mlir/Interfaces/SideEffectInterfaces.h    |  4 +
 .../Transforms/LoopInvariantCodeMotionUtils.h | 12 ++-
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 59 ++++++++++-
 mlir/lib/Interfaces/SideEffectInterfaces.cpp  |  7 ++
 .../Utils/LoopInvariantCodeMotionUtils.cpp    | 96 +++++++++++++++---
 .../loop-invariant-code-motion.mlir           | 99 +++++++++++++++++++
 mlir/test/lib/Dialect/Test/TestOps.td         | 44 +++++++++
 9 files changed, 326 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 23c597a1ca5108..b54df8e3ef313d 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -139,6 +139,7 @@ def ForOp : SCF_Op<"for",
         "getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
         "getLoopUpperBounds", "getYieldedValuesMutable",
         "promoteIfSingleIteration", "replaceWithAdditionalYields",
+        "wrapInTripCountCheck", "unwrapTripCountCheck",
         "yieldTiledValuesAndReplace"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
        ConditionallySpeculatable,
@@ -302,7 +303,7 @@ def ForallOp : SCF_Op<"forall", [
        AttrSizedOperandSegments,
        AutomaticAllocationScope,
        DeclareOpInterfaceMethods<LoopLikeOpInterface,
-          ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars", 
+          ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
            "getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
            "promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
        RecursiveMemoryEffects,
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index c6bffe347419e5..831830130b0ddc 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -79,6 +79,26 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       /*methodBody=*/"",
       /*defaultImplementation=*/"op->moveBefore($_op);"
     >,
+    InterfaceMethod<[{
+        Wraps the loop into a trip-count check.
+      }],
+      /*retTy=*/"FailureOr<std::pair<::mlir::Operation *, ::mlir::Region *>>",
+      /*methodName=*/"wrapInTripCountCheck",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/"return ::mlir::failure();"
+    >,
+    InterfaceMethod<[{
+        Unwraps the trip-count check.
+      }],
+      /*retTy=*/"::llvm::LogicalResult",
+      /*methodName=*/"unwrapTripCountCheck",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::failure();
+      }]
+    >,
     InterfaceMethod<[{
         Promotes the loop body to its containing block if the loop is known to
         have a single iteration. Returns "success" if the promotion was
diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
index aef7ec622fe4f8..1a7f66e2234949 100644
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
@@ -433,6 +433,10 @@ bool wouldOpBeTriviallyDead(Operation *op);
 /// conditions are satisfied.
 bool isMemoryEffectFree(Operation *op);
 
+/// Returns true if the given operation implements `MemoryEffectOpInterface` and
+/// has only read effects.
+bool hasOnlyReadEffect(Operation *op);
+
 /// Returns the side effects of an operation. If the operation has
 /// RecursiveMemoryEffects, include all side effects of child operations.
 ///
diff --git a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
index 3ceef44d799e89..ae6719abe79c00 100644
--- a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
@@ -48,15 +48,19 @@ class Value;
 /// }
 /// ```
 ///
-/// Users must supply three callbacks.
+/// Users must supply five callbacks.
 ///
 /// - `isDefinedOutsideRegion` returns true if the given value is invariant with
 ///   respect to the given region. A common implementation might be:
 ///   `value.getParentRegion()->isProperAncestor(region)`.
 /// - `shouldMoveOutOfRegion` returns true if the provided operation can be
-///   moved of the given region, e.g. if it is side-effect free.
+///   moved out of the given region, e.g. if it is side-effect free or has only
+///   read side effects.
+/// - `wrapInGuard` wraps the given operation in a trip-count check guard.
 /// - `moveOutOfRegion` moves the operation out of the given region. A common
 ///   implementation might be: `op->moveBefore(region->getParentOp())`.
+/// - `unwrapGuard` unwraps the trip-count check if there is no op guarded by
+///   this check.
 ///
 /// An operation is moved if all of its operands satisfy
 /// `isDefinedOutsideRegion` and it satisfies `shouldMoveOutOfRegion`.
@@ -66,7 +70,9 @@ size_t moveLoopInvariantCode(
     ArrayRef<Region *> regions,
     function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
     function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
-    function_ref<void(Operation *, Region *)> moveOutOfRegion);
+    function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
+    function_ref<void(Operation *, Region *)> moveOutOfRegion,
+    function_ref<LogicalResult()> unwrapGuard);
 
 /// Move side-effect free loop invariant code out of a loop-like op using
 /// methods provided by the interface.
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index eded1c394f126c..d246ebdaaea2f3 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -395,6 +395,60 @@ std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
 
 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
 
+/// Guards this loop with an `scf.if` that tests `upperBound > lowerBound`.
+/// External users of the loop results are redirected to the guard results; the
+/// else branch yields the init args. Returns the guard and the loop region.
+FailureOr<std::pair<Operation *, Region *>> ForOp::wrapInTripCountCheck() {
+  IRRewriter rewriter(this->getContext());
+  OpBuilder::InsertionGuard insertGuard(rewriter);
+  rewriter.setInsertionPointAfter(this->getOperation());
+  // scf.for compares its bounds as signed integers, so the guard must use a
+  // signed predicate; `ugt` misjudges loops with a negative bound.
+  auto loc = this->getLoc();
+  auto cmpIOp = rewriter.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::sgt, getUpperBound(), getLowerBound());
+  scf::YieldOp yieldInThen;
+  auto ifOp = rewriter.create<scf::IfOp>(
+      loc, cmpIOp,
+      [&](OpBuilder &builder, Location loc) {
+        yieldInThen = builder.create<scf::YieldOp>(loc, this->getResults());
+      },
+      [&](OpBuilder &builder, Location loc) {
+        builder.create<scf::YieldOp>(loc, this->getInitArgs());
+      });
+  for (auto [forOpResult, ifOpResult] :
+       llvm::zip(this->getResults(), ifOp.getResults()))
+    rewriter.replaceAllUsesExcept(forOpResult, ifOpResult, yieldInThen);
+  // Move the scf.for into the then block.
+  rewriter.moveOpBefore(this->getOperation(), yieldInThen);
+  return std::make_pair(ifOp.getOperation(), &this->getRegion());
+}
+
+/// Undoes `wrapInTripCountCheck`: moves this loop out of its parent `scf.if`,
+/// forwards the guard's results to the loop's results, and erases the guard.
+/// Fails without modifying the IR unless the parent is an `scf.if` with as many
+/// results as this loop and a condition defined by an `arith.cmpi`.
+LogicalResult ForOp::unwrapTripCountCheck() {
+  auto ifOp = dyn_cast_or_null<scf::IfOp>((*this)->getParentOp());
+  if (!ifOp || ifOp->getNumResults() != (*this)->getNumResults())
+    return failure();
+
+  // `getDefiningOp<OpTy>()` is null both for block arguments and for other
+  // kinds of producers, so no null dereference can happen here.
+  auto cmpOp = ifOp.getCondition().getDefiningOp<arith::CmpIOp>();
+  if (!cmpOp)
+    return failure();
+
+  IRRewriter rewriter(ifOp->getContext());
+  rewriter.moveOpBefore(this->getOperation(), ifOp);
+  rewriter.replaceOp(ifOp, (*this)->getResults());
+
+  // The comparison may have other users; erase it only once it is dead.
+  if (cmpOp->use_empty())
+    rewriter.eraseOp(cmpOp);
+  return success();
+}
+
 /// Promotes the loop body of a forOp to its containing block if the forOp
 /// it can be determined that the loop has a single iteration.
 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
@@ -3397,9 +3451,8 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
 
   if (functionType.getNumInputs() != operands.size()) {
     return parser.emitError(typeLoc)
-           << "expected as many input types as operands "
-           << "(expected " << operands.size() << " got "
-           << functionType.getNumInputs() << ")";
+           << "expected as many input types as operands " << "(expected "
+           << operands.size() << " got " << functionType.getNumInputs() << ")";
   }
 
   // Resolve input operands.
diff --git a/mlir/lib/Interfaces/SideEffectInterfaces.cpp b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
index c9feb001a19844..f45d5f3d227407 100644
--- a/mlir/lib/Interfaces/SideEffectInterfaces.cpp
+++ b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
@@ -306,6 +306,13 @@ bool mlir::wouldOpBeTriviallyDead(Operation *op) {
   return wouldOpBeTriviallyDeadImpl(op);
 }
 
+/// Returns true only when `op` implements `MemoryEffectOpInterface` and its
+/// `onlyHasEffect<MemoryEffects::Read>()` query holds; false otherwise.
+bool mlir::hasOnlyReadEffect(Operation *op) {
+  auto effectIface = dyn_cast<MemoryEffectOpInterface>(op);
+  return effectIface && effectIface.onlyHasEffect<MemoryEffects::Read>();
+}
+
 bool mlir::isMemoryEffectFree(Operation *op) {
   if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
     if (!memInterface.hasNoEffect())
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 7460746934a78c..b0bf86c0c8e878 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -56,48 +56,117 @@ static bool canBeHoisted(Operation *op,
       op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
 }
 
+/// Returns true if any operand of `op`, or of an op nested inside it, fails
+/// `condition`; the walk is interrupted at the first op with such an operand.
+static bool dependsOnGuarded(Operation *op,
+                             function_ref<bool(OpOperand &)> condition) {
+  WalkResult outcome = op->walk([&](Operation *nested) {
+    if (!llvm::all_of(nested->getOpOperands(), condition))
+      return WalkResult::interrupt();
+    return WalkResult::advance();
+  });
+  return outcome.wasInterrupted();
+}
+
+/// Value-based form: true if a (nested) operand fails `definedOutsideGuard`.
+static bool dependsOnGuarded(Operation *op,
+                             function_ref<bool(Value)> definedOutsideGuard) {
+  auto check = [&](OpOperand &use) { return definedOutsideGuard(use.get()); };
+  return dependsOnGuarded(op, check);
+}
+
+/// Returns true if every op directly nested in the regions of `loop` is either
+/// memory-effect free or has only read effects.
+static bool loopSideEffectFreeOrHasOnlyReadEffect(Operation *loop) {
+  auto isPureOrReadOnly = [](Operation &nested) {
+    return isMemoryEffectFree(&nested) || hasOnlyReadEffect(&nested);
+  };
+  for (Region &body : loop->getRegions())
+    if (!llvm::all_of(body.getOps(), isPureOrReadOnly))
+      return false;
+  return true;
+}
+
 size_t mlir::moveLoopInvariantCode(
     ArrayRef<Region *> regions,
     function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
     function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
-    function_ref<void(Operation *, Region *)> moveOutOfRegion) {
+    function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
+    function_ref<void(Operation *, Region *)> moveOutOfRegion,
+    function_ref<LogicalResult()> unwrapGuard) {
   size_t numMoved = 0;
 
   for (Region *region : regions) {
     LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
                             << *region->getParentOp() << "\n");
 
+    auto loopSideEffectFreeOrHasOnlyReadSideEffect =
+        loopSideEffectFreeOrHasOnlyReadEffect(region->getParentOp());
+
+    size_t numMovedWithoutGuard = 0;
+
+    FailureOr<std::pair<Operation *, Region *>> ifOpAndRegion = wrapInGuard();
+    Region *loopRegion = region;
+    auto isLoopWrapped = false;
+    if (succeeded(ifOpAndRegion)) {
+      loopRegion = ifOpAndRegion->second;
+      isLoopWrapped = true;
+    }
+
     std::queue<Operation *> worklist;
     // Add top-level operations in the loop body to the worklist.
-    for (Operation &op : region->getOps())
+    for (Operation &op : loopRegion->getOps())
       worklist.push(&op);
 
     auto definedOutside = [&](Value value) {
-      return isDefinedOutsideRegion(value, region);
+      return isDefinedOutsideRegion(value, loopRegion);
+    };
+
+    auto definedOutsideGuard = [&](Value value) {
+      return isDefinedOutsideRegion(value, loopRegion->getParentRegion());
     };
 
     while (!worklist.empty()) {
       Operation *op = worklist.front();
       worklist.pop();
       // Skip ops that have already been moved. Check if the op can be hoisted.
-      if (op->getParentRegion() != region)
+      if (op->getParentRegion() != loopRegion)
         continue;
 
       LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
-      if (!shouldMoveOutOfRegion(op, region) ||
+
+      if (!shouldMoveOutOfRegion(op, loopRegion) ||
           !canBeHoisted(op, definedOutside))
         continue;
+      // Can only hoist pure ops (side-effect free) when there is an op with
+      // write side effects in the loop.
+      if (!loopSideEffectFreeOrHasOnlyReadSideEffect && !isMemoryEffectFree(op))
+        continue;
 
       LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
-      moveOutOfRegion(op, region);
+
+      auto moveWithoutGuard = isMemoryEffectFree(op) &&
+                              !dependsOnGuarded(op, definedOutsideGuard) &&
+                              isLoopWrapped;
+      numMovedWithoutGuard += moveWithoutGuard;
+
+      moveOutOfRegion(op, moveWithoutGuard ? loopRegion->getParentRegion()
+                                           : loopRegion);
       ++numMoved;
 
       // Since the op has been moved, we need to check its users within the
       // top-level of the loop body.
       for (Operation *user : op->getUsers())
-        if (user->getParentRegion() == region)
+        if (user->getParentRegion() == loopRegion)
           worklist.push(user);
     }
+
+    // Unwrap the loop if it was wrapped but no ops were moved in the guard.
+    if (isLoopWrapped && numMovedWithoutGuard == numMoved) {
+      auto tripCountCheckUnwrapped = unwrapGuard();
+      if (failed(tripCountCheckUnwrapped))
+        llvm_unreachable("Should not fail unwrapping trip-count check");
+    }
   }
 
   return numMoved;
@@ -106,13 +175,18 @@ size_t mlir::moveLoopInvariantCode(
 size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
   return moveLoopInvariantCode(
       loopLike.getLoopRegions(),
-      [&](Value value, Region *) {
-        return loopLike.isDefinedOutsideOfLoop(value);
+      [&](Value value, Region *region) {
+        return !region->isAncestor(value.getParentRegion());
       },
       [&](Operation *op, Region *) {
-        return isMemoryEffectFree(op) && isSpeculatable(op);
+        return isSpeculatable(op) &&
+               (isMemoryEffectFree(op) || hasOnlyReadEffect(op));
+      },
+      [&]() { return loopLike.wrapInTripCountCheck(); },
+      [&](Operation *op, Region *region) {
+        op->moveBefore(region->getParentOp());
       },
-      [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
+      [&]() { return loopLike.unwrapTripCountCheck(); });
 }
 
 namespace {
diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir
index e4c423ce7052bf..f2c4f09e1e7b22 100644
--- a/mlir/test/Transforms/loop-invariant-code-motion.mlir
+++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir
@@ -593,6 +593,105 @@ func.func @test_recursively_speculatable_op_failure(%lb: index, %ub: index, %ste
   return
 }
 
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success
+func.func @test_speculatable_op_with_read_side_effect_success(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: arith.cmpi
+  // CHECK-NEXT: test.always_speculatable_op
+  // CHECK-NEXT: scf.if
+  // CHECK-NEXT: test.speculatable_op_with_memread
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK-NOT: test.speculatable_op_with_memread
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always = "test.always_speculatable_op"() : () -> i32
+    %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i: index to i32
+    %i_sum = arith.addi %acc, %i_cast : i32
+    %test_sum = arith.addi %i_sum, %always_read : i32
+    scf.yield %test_sum : i32
+  }
+  return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success_with_dependents
+func.func @test_speculatable_op_with_read_side_effect_success_with_dependents(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: arith.cmpi
+  // CHECK-NEXT: test.always_speculatable_op
+  // CHECK-NEXT: scf.if
+  // CHECK-NEXT: test.speculatable_op_with_memread
+  // CHECK-NEXT: arith.addi
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK-NOT: test.speculatable_op_with_memread
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always = "test.always_speculatable_op"() : () -> i32
+    %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %add = arith.addi %always_read, %cst_0 : i32
+    %i_cast = arith.index_cast %i: index to i32
+    %i_sum = arith.addi %acc, %i_cast : i32
+    %test_sum = arith.addi %i_sum, %add : i32
+    scf.yield %test_sum : i32
+  }
+  return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_write
+func.func @test_speculatable_op_with_read_side_effect_failure_due_to_write(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: test.always_speculatable_op
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK: test.speculatable_op_with_memread
+  // CHECK: test.speculatable_op_with_memwrite
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always = "test.always_speculatable_op"() : () -> i32
+    %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i: index to i32
+    %i_sum = arith.addi %acc, %i_cast : i32
+    %test_sum = arith.addi %i_sum, %always_read : i32
+    %always_write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
+    scf.yield %test_sum : i32
+  }
+  return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_nested_write
+func.func @test_speculatable_op_with_read_side_effect_failure_due_to_nested_write(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: test.always_speculatable_op
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK: test.speculatable_op_with_memread
+  // CHECK: scf.for
+  // CHECK: scf.if
+  // CHECK: test.speculatable_op_with_memwrite
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always = "test.always_speculatable_op"() : () -> i32
+    %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i: index to i32
+    %i_sum = arith.addi %acc, %i_cast : i32
+    %test_sum = arith.addi %i_sum, %always_read : i32
+    scf.for %j = %lb to %ub step %step {
+      %eq42 = arith.cmpi eq, %j, %ind_42 : index
+      scf.if %eq42 {
+        %always_write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
+      }
+    }
+    scf.yield %test_sum : i32
+  }
+  return %sum_result : i32
+}
+
 // -----
 
 func.func @speculate_tensor_dim_unknown_rank_unknown_dim(
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index d24d52f356d88f..1b834a0a9266b2 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2916,6 +2916,50 @@ def RecursivelySpeculatableOp : TEST_Op<"recursively_speculatable_op", [
   let regions = (region SizedRegion<1>:$body);
 }
 
+def SpeculatableOpWithReadSideEffect : TEST_Op<"speculatable_op_with_memread",
+    [ConditionallySpeculatable, MemoryEffects<[MemRead]>]> {
+  let description = [{
+    Op used to test LICM conditional speculation.  This op can always be
+    speculatively executed and has only memory read effect.
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$inputs);
+  let results = (outs Variadic<AnyType>:$outputs);
+
+  let extraClassDeclaration = [{
+    ::mlir::Speculation::Speculatability getSpeculatability();
+  }];
+
+  let extraClassDefinition = [{
+    ::mlir::Speculation::Speculatability
+    SpeculatableOpWithReadSideEffect::getSpeculatability() {
+      return ::mlir::Speculation::Speculatable;
+    }
+  }];
+}
+
+def SpeculatableOpWithWriteSideEffect : TEST_Op<"speculatable_op_with_memwrite",
+    [ConditionallySpeculatable, MemoryEffects<[MemWrite]>]> {
+  let description = [{
+    Op used to test LICM conditional speculation.  This op can always be
+    speculatively executed and has only memory write effect.
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$inputs);
+  let results = (outs Variadic<AnyType>:$outputs);
+
+  let extraClassDeclaration = [{
+    ::mlir::Speculation::Speculatability getSpeculatability();
+  }];
+
+  let extraClassDefinition = [{
+    ::mlir::Speculation::Speculatability
+    SpeculatableOpWithWriteSideEffect::getSpeculatability() {
+      return ::mlir::Speculation::Speculatable;
+    }
+  }];
+}
+
 //===---------------------------------------------------------------------===//
 // Test CSE
 //===---------------------------------------------------------------------===//



More information about the Mlir-commits mailing list