[Mlir-commits] [mlir] Enable LICM for ops with only read side effects in scf.for (PR #120302)
Arda Unal
llvmlistbot at llvm.org
Mon Jan 6 15:45:21 PST 2025
https://github.com/ardaunal updated https://github.com/llvm/llvm-project/pull/120302
>From 0e596fdeb855ec27a4d78918971af819c0769825 Mon Sep 17 00:00:00 2001
From: ardau <ardau at meta.com>
Date: Tue, 17 Dec 2024 12:29:21 -0800
Subject: [PATCH 1/3] Enable LICM for ops with read side effects in scf.for
wrapped by a guard
---
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 3 +-
.../mlir/Interfaces/LoopLikeInterface.td | 20 ++++
.../mlir/Interfaces/SideEffectInterfaces.h | 4 +
.../Transforms/LoopInvariantCodeMotionUtils.h | 12 ++-
mlir/lib/Dialect/SCF/IR/SCF.cpp | 59 ++++++++++-
mlir/lib/Interfaces/SideEffectInterfaces.cpp | 7 ++
.../Utils/LoopInvariantCodeMotionUtils.cpp | 96 +++++++++++++++---
.../loop-invariant-code-motion.mlir | 99 +++++++++++++++++++
mlir/test/lib/Dialect/Test/TestOps.td | 44 +++++++++
9 files changed, 326 insertions(+), 18 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 23c597a1ca5108..b54df8e3ef313d 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -139,6 +139,7 @@ def ForOp : SCF_Op<"for",
"getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
"getLoopUpperBounds", "getYieldedValuesMutable",
"promoteIfSingleIteration", "replaceWithAdditionalYields",
+ "wrapInTripCountCheck", "unwrapTripCountCheck",
"yieldTiledValuesAndReplace"]>,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
ConditionallySpeculatable,
@@ -302,7 +303,7 @@ def ForallOp : SCF_Op<"forall", [
AttrSizedOperandSegments,
AutomaticAllocationScope,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
- ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
+ ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
"getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
RecursiveMemoryEffects,
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index c6bffe347419e5..831830130b0ddc 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -79,6 +79,26 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
/*methodBody=*/"",
/*defaultImplementation=*/"op->moveBefore($_op);"
>,
+ InterfaceMethod<[{
+ Wraps the loop into a trip-count check.
+ }],
+ /*retTy=*/"FailureOr<std::pair<::mlir::Operation *, ::mlir::Region *>>",
+ /*methodName=*/"wrapInTripCountCheck",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return ::mlir::failure();"
+ >,
+ InterfaceMethod<[{
+ Unwraps the trip-count check.
+ }],
+ /*retTy=*/"::llvm::LogicalResult",
+ /*methodName=*/"unwrapTripCountCheck",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::mlir::failure();
+ }]
+ >,
InterfaceMethod<[{
Promotes the loop body to its containing block if the loop is known to
have a single iteration. Returns "success" if the promotion was
diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
index aef7ec622fe4f8..1a7f66e2234949 100644
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
@@ -433,6 +433,10 @@ bool wouldOpBeTriviallyDead(Operation *op);
/// conditions are satisfied.
bool isMemoryEffectFree(Operation *op);
+/// Returns true if the given operation implements `MemoryEffectOpInterface` and
+/// has only read effects.
+bool hasOnlyReadEffect(Operation *op);
+
/// Returns the side effects of an operation. If the operation has
/// RecursiveMemoryEffects, include all side effects of child operations.
///
diff --git a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
index 3ceef44d799e89..ae6719abe79c00 100644
--- a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
@@ -48,15 +48,19 @@ class Value;
/// }
/// ```
///
-/// Users must supply three callbacks.
+/// Users must supply five callbacks.
///
/// - `isDefinedOutsideRegion` returns true if the given value is invariant with
/// respect to the given region. A common implementation might be:
/// `value.getParentRegion()->isProperAncestor(region)`.
/// - `shouldMoveOutOfRegion` returns true if the provided operation can be
-/// moved of the given region, e.g. if it is side-effect free.
+///   moved out of the given region, e.g. if it is side-effect free or has only read
+/// side effects.
+/// - `wrapInGuard` wraps the given operation in a trip-count check guard.
/// - `moveOutOfRegion` moves the operation out of the given region. A common
/// implementation might be: `op->moveBefore(region->getParentOp())`.
+/// - `unwrapGuard` unwraps the trip-count check if there is no op guarded by
+/// this check.
///
/// An operation is moved if all of its operands satisfy
/// `isDefinedOutsideRegion` and it satisfies `shouldMoveOutOfRegion`.
@@ -66,7 +70,9 @@ size_t moveLoopInvariantCode(
ArrayRef<Region *> regions,
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
- function_ref<void(Operation *, Region *)> moveOutOfRegion);
+ function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
+ function_ref<void(Operation *, Region *)> moveOutOfRegion,
+ function_ref<LogicalResult()> unwrapGuard);
/// Move side-effect free loop invariant code out of a loop-like op using
/// methods provided by the interface.
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index eded1c394f126c..d246ebdaaea2f3 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -395,6 +395,60 @@ std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
+FailureOr<std::pair<Operation *, Region *>> ForOp::wrapInTripCountCheck() {
+
+ IRRewriter rewriter(this->getContext());
+ OpBuilder::InsertionGuard insertGuard(rewriter);
+ rewriter.setInsertionPointAfter(this->getOperation());
+
+ auto loc = this->getLoc();
+ auto cmpIOp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
+ this->getUpperBound(),
+ this->getLowerBound());
+ scf::YieldOp yieldInThen;
+ // Create the trip-count check.
+ auto ifOp = rewriter.create<scf::IfOp>(
+ loc, cmpIOp,
+ [&](OpBuilder &builder, Location loc) {
+ yieldInThen = builder.create<scf::YieldOp>(loc, this->getResults());
+ },
+ [&](OpBuilder &builder, Location loc) {
+ builder.create<scf::YieldOp>(loc, this->getInitArgs());
+ });
+
+ for (auto [forOpResult, ifOpResult] :
+ llvm::zip(this->getResults(), ifOp.getResults()))
+ rewriter.replaceAllUsesExcept(forOpResult, ifOpResult, yieldInThen);
+ // Move the scf.for into the then block.
+ rewriter.moveOpBefore(this->getOperation(), yieldInThen);
+ return std::make_pair(ifOp.getOperation(), &this->getRegion());
+}
+
+LogicalResult ForOp::unwrapTripCountCheck() {
+ auto ifOp = (*this)->getParentRegion()->getParentOp();
+ if (!isa<scf::IfOp>(ifOp))
+ return failure();
+
+ IRRewriter rewriter(ifOp->getContext());
+ OpBuilder::InsertionGuard insertGuard(rewriter);
+ rewriter.setInsertionPoint(ifOp);
+
+ auto cmpOp = ifOp->getOperand(0).getDefiningOp();
+ if (!isa<arith::CmpIOp>(cmpOp))
+ return failure();
+
+ auto wrappedForOp = this->getOperation();
+ rewriter.moveOpBefore(wrappedForOp, ifOp);
+
+ for (auto [forOpResult, ifOpResult] :
+ llvm::zip(wrappedForOp->getResults(), ifOp->getResults()))
+ rewriter.replaceAllUsesWith(ifOpResult, forOpResult);
+
+ rewriter.eraseOp(ifOp);
+ rewriter.eraseOp(cmpOp);
+ return success();
+}
+
/// Promotes the loop body of a forOp to its containing block if the forOp
/// it can be determined that the loop has a single iteration.
LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
@@ -3397,9 +3451,8 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
if (functionType.getNumInputs() != operands.size()) {
return parser.emitError(typeLoc)
- << "expected as many input types as operands "
- << "(expected " << operands.size() << " got "
- << functionType.getNumInputs() << ")";
+ << "expected as many input types as operands " << "(expected "
+ << operands.size() << " got " << functionType.getNumInputs() << ")";
}
// Resolve input operands.
diff --git a/mlir/lib/Interfaces/SideEffectInterfaces.cpp b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
index c9feb001a19844..f45d5f3d227407 100644
--- a/mlir/lib/Interfaces/SideEffectInterfaces.cpp
+++ b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
@@ -306,6 +306,13 @@ bool mlir::wouldOpBeTriviallyDead(Operation *op) {
return wouldOpBeTriviallyDeadImpl(op);
}
+bool mlir::hasOnlyReadEffect(Operation *op) {
+ if (auto memEffects = dyn_cast<MemoryEffectOpInterface>(op)) {
+ return memEffects.onlyHasEffect<MemoryEffects::Read>();
+ }
+ return false;
+}
+
bool mlir::isMemoryEffectFree(Operation *op) {
if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
if (!memInterface.hasNoEffect())
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 7460746934a78c..b0bf86c0c8e878 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -56,48 +56,117 @@ static bool canBeHoisted(Operation *op,
op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
}
+static bool dependsOnGuarded(Operation *op,
+ function_ref<bool(OpOperand &)> condition) {
+ auto walkFn = [&](Operation *child) {
+ for (OpOperand &operand : child->getOpOperands()) {
+ if (!condition(operand))
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ };
+ return op->walk(walkFn).wasInterrupted();
+}
+
+static bool dependsOnGuarded(Operation *op,
+ function_ref<bool(Value)> definedOutsideGuard) {
+ return dependsOnGuarded(op, [&](OpOperand &operand) {
+ return definedOutsideGuard(operand.get());
+ });
+}
+
+static bool loopSideEffectFreeOrHasOnlyReadEffect(Operation *loop) {
+ for (Region ®ion : loop->getRegions()) {
+ for (Block &block : region.getBlocks()) {
+ for (Operation &op : block.getOperations()) {
+ if (!isMemoryEffectFree(&op) && !hasOnlyReadEffect(&op))
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
size_t mlir::moveLoopInvariantCode(
ArrayRef<Region *> regions,
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
- function_ref<void(Operation *, Region *)> moveOutOfRegion) {
+ function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
+ function_ref<void(Operation *, Region *)> moveOutOfRegion,
+ function_ref<LogicalResult()> unwrapGuard) {
size_t numMoved = 0;
for (Region *region : regions) {
LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
<< *region->getParentOp() << "\n");
+ auto loopSideEffectFreeOrHasOnlyReadSideEffect =
+ loopSideEffectFreeOrHasOnlyReadEffect(region->getParentOp());
+
+ size_t numMovedWithoutGuard = 0;
+
+ FailureOr<std::pair<Operation *, Region *>> ifOpAndRegion = wrapInGuard();
+ Region *loopRegion = region;
+ auto isLoopWrapped = false;
+ if (succeeded(ifOpAndRegion)) {
+ loopRegion = ifOpAndRegion->second;
+ isLoopWrapped = true;
+ }
+
std::queue<Operation *> worklist;
// Add top-level operations in the loop body to the worklist.
- for (Operation &op : region->getOps())
+ for (Operation &op : loopRegion->getOps())
worklist.push(&op);
auto definedOutside = [&](Value value) {
- return isDefinedOutsideRegion(value, region);
+ return isDefinedOutsideRegion(value, loopRegion);
+ };
+
+ auto definedOutsideGuard = [&](Value value) {
+ return isDefinedOutsideRegion(value, loopRegion->getParentRegion());
};
while (!worklist.empty()) {
Operation *op = worklist.front();
worklist.pop();
// Skip ops that have already been moved. Check if the op can be hoisted.
- if (op->getParentRegion() != region)
+ if (op->getParentRegion() != loopRegion)
continue;
LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
- if (!shouldMoveOutOfRegion(op, region) ||
+
+ if (!shouldMoveOutOfRegion(op, loopRegion) ||
!canBeHoisted(op, definedOutside))
continue;
+ // Can only hoist pure ops (side-effect free) when there is an op with
+ // write side effects in the loop.
+ if (!loopSideEffectFreeOrHasOnlyReadSideEffect && !isMemoryEffectFree(op))
+ continue;
LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
- moveOutOfRegion(op, region);
+
+ auto moveWithoutGuard = isMemoryEffectFree(op) &&
+ !dependsOnGuarded(op, definedOutsideGuard) &&
+ isLoopWrapped;
+ numMovedWithoutGuard += moveWithoutGuard;
+
+ moveOutOfRegion(op, moveWithoutGuard ? loopRegion->getParentRegion()
+ : loopRegion);
++numMoved;
// Since the op has been moved, we need to check its users within the
// top-level of the loop body.
for (Operation *user : op->getUsers())
- if (user->getParentRegion() == region)
+ if (user->getParentRegion() == loopRegion)
worklist.push(user);
}
+
+ // Unwrap the loop if it was wrapped but no ops were moved in the guard.
+ if (isLoopWrapped && numMovedWithoutGuard == numMoved) {
+ auto tripCountCheckUnwrapped = unwrapGuard();
+ if (failed(tripCountCheckUnwrapped))
+ llvm_unreachable("Should not fail unwrapping trip-count check");
+ }
}
return numMoved;
@@ -106,13 +175,18 @@ size_t mlir::moveLoopInvariantCode(
size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
return moveLoopInvariantCode(
loopLike.getLoopRegions(),
- [&](Value value, Region *) {
- return loopLike.isDefinedOutsideOfLoop(value);
+ [&](Value value, Region *region) {
+ return !region->isAncestor(value.getParentRegion());
},
[&](Operation *op, Region *) {
- return isMemoryEffectFree(op) && isSpeculatable(op);
+ return isSpeculatable(op) &&
+ (isMemoryEffectFree(op) || hasOnlyReadEffect(op));
+ },
+ [&]() { return loopLike.wrapInTripCountCheck(); },
+ [&](Operation *op, Region *region) {
+ op->moveBefore(region->getParentOp());
},
- [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
+ [&]() { return loopLike.unwrapTripCountCheck(); });
}
namespace {
diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir
index e4c423ce7052bf..f2c4f09e1e7b22 100644
--- a/mlir/test/Transforms/loop-invariant-code-motion.mlir
+++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir
@@ -593,6 +593,105 @@ func.func @test_recursively_speculatable_op_failure(%lb: index, %ub: index, %ste
return
}
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success
+func.func @test_speculatable_op_with_read_side_effect_success(%lb: index, %ub: index, %step: index) -> i32 {
+ // CHECK: arith.cmpi
+ // CHECK-NEXT: test.always_speculatable_op
+ // CHECK-NEXT: scf.if
+ // CHECK-NEXT: test.speculatable_op_with_memread
+ // CHECK-NEXT: scf.for
+ // CHECK-NOT: test.always_speculatable_op
+ // CHECK-NOT: test.speculatable_op_with_memread
+ %cst_0 = arith.constant 0 : i32
+ %cst_42 = arith.constant dense<42> : tensor<64xi32>
+ %ind_42 = arith.constant 42 : index
+ %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+ %always = "test.always_speculatable_op"() : () -> i32
+ %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+ %i_cast = arith.index_cast %i: index to i32
+ %i_sum = arith.addi %acc, %i_cast : i32
+ %test_sum = arith.addi %i_sum, %always_read : i32
+ scf.yield %test_sum : i32
+ }
+ return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success_with_dependents
+func.func @test_speculatable_op_with_read_side_effect_success_with_dependents(%lb: index, %ub: index, %step: index) -> i32 {
+ // CHECK: arith.cmpi
+ // CHECK-NEXT: test.always_speculatable_op
+ // CHECK-NEXT: scf.if
+ // CHECK-NEXT: test.speculatable_op_with_memread
+ // CHECK-NEXT: arith.addi
+ // CHECK-NEXT: scf.for
+ // CHECK-NOT: test.always_speculatable_op
+ // CHECK-NOT: test.speculatable_op_with_memread
+ %cst_0 = arith.constant 0 : i32
+ %cst_42 = arith.constant dense<42> : tensor<64xi32>
+ %ind_42 = arith.constant 42 : index
+ %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+ %always = "test.always_speculatable_op"() : () -> i32
+ %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+ %add = arith.addi %always_read, %cst_0 : i32
+ %i_cast = arith.index_cast %i: index to i32
+ %i_sum = arith.addi %acc, %i_cast : i32
+ %test_sum = arith.addi %i_sum, %add : i32
+ scf.yield %test_sum : i32
+ }
+ return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_write
+func.func @test_speculatable_op_with_read_side_effect_failure_due_to_write(%lb: index, %ub: index, %step: index) -> i32 {
+ // CHECK: test.always_speculatable_op
+ // CHECK-NEXT: scf.for
+ // CHECK-NOT: test.always_speculatable_op
+ // CHECK: test.speculatable_op_with_memread
+ // CHECK: test.speculatable_op_with_memwrite
+ %cst_0 = arith.constant 0 : i32
+ %cst_42 = arith.constant dense<42> : tensor<64xi32>
+ %ind_42 = arith.constant 42 : index
+ %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+ %always = "test.always_speculatable_op"() : () -> i32
+ %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+ %i_cast = arith.index_cast %i: index to i32
+ %i_sum = arith.addi %acc, %i_cast : i32
+ %test_sum = arith.addi %i_sum, %always_read : i32
+ %always_write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
+ scf.yield %test_sum : i32
+ }
+ return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_nested_write
+func.func @test_speculatable_op_with_read_side_effect_failure_due_to_nested_write(%lb: index, %ub: index, %step: index) -> i32 {
+ // CHECK: test.always_speculatable_op
+ // CHECK-NEXT: scf.for
+ // CHECK-NOT: test.always_speculatable_op
+ // CHECK: test.speculatable_op_with_memread
+ // CHECK: scf.for
+ // CHECK: scf.if
+ // CHECK: test.speculatable_op_with_memwrite
+ %cst_0 = arith.constant 0 : i32
+ %cst_42 = arith.constant dense<42> : tensor<64xi32>
+ %ind_42 = arith.constant 42 : index
+ %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+ %always = "test.always_speculatable_op"() : () -> i32
+ %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+ %i_cast = arith.index_cast %i: index to i32
+ %i_sum = arith.addi %acc, %i_cast : i32
+ %test_sum = arith.addi %i_sum, %always_read : i32
+ scf.for %j = %lb to %ub step %step {
+ %eq42 = arith.cmpi eq, %j, %ind_42 : index
+ scf.if %eq42 {
+ %always_write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
+ }
+ }
+ scf.yield %test_sum : i32
+ }
+ return %sum_result : i32
+}
+
// -----
func.func @speculate_tensor_dim_unknown_rank_unknown_dim(
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index d24d52f356d88f..1b834a0a9266b2 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2916,6 +2916,50 @@ def RecursivelySpeculatableOp : TEST_Op<"recursively_speculatable_op", [
let regions = (region SizedRegion<1>:$body);
}
+def SpeculatableOpWithReadSideEffect : TEST_Op<"speculatable_op_with_memread",
+ [ConditionallySpeculatable, MemoryEffects<[MemRead]>]> {
+ let description = [{
+ Op used to test LICM conditional speculation. This op can always be
+ speculatively executed and has only memory read effect.
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$inputs);
+ let results = (outs Variadic<AnyType>:$outputs);
+
+ let extraClassDeclaration = [{
+ ::mlir::Speculation::Speculatability getSpeculatability();
+ }];
+
+ let extraClassDefinition = [{
+ ::mlir::Speculation::Speculatability
+ SpeculatableOpWithReadSideEffect::getSpeculatability() {
+ return ::mlir::Speculation::Speculatable;
+ }
+ }];
+}
+
+def SpeculatableOpWithWriteSideEffect : TEST_Op<"speculatable_op_with_memwrite",
+ [ConditionallySpeculatable, MemoryEffects<[MemWrite]>]> {
+ let description = [{
+    Op used to test LICM conditional speculation. This op can always be
+    speculatively executed and has only memory write effect.
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$inputs);
+ let results = (outs Variadic<AnyType>:$outputs);
+
+ let extraClassDeclaration = [{
+ ::mlir::Speculation::Speculatability getSpeculatability();
+ }];
+
+ let extraClassDefinition = [{
+ ::mlir::Speculation::Speculatability
+ SpeculatableOpWithWriteSideEffect::getSpeculatability() {
+ return ::mlir::Speculation::Speculatable;
+ }
+ }];
+}
+
//===---------------------------------------------------------------------===//
// Test CSE
//===---------------------------------------------------------------------===//
>From c5328108c35037fb408f1d4830231c6b18cbc74d Mon Sep 17 00:00:00 2001
From: ardau <ardau at meta.com>
Date: Thu, 19 Dec 2024 17:01:55 -0800
Subject: [PATCH 2/3] fixup! Enable LICM for ops with read side effects in
scf.for wrapped by a guard
---
mlir/lib/Interfaces/SideEffectInterfaces.cpp | 17 +++++++++++++++--
.../Utils/LoopInvariantCodeMotionUtils.cpp | 2 +-
2 files changed, 16 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Interfaces/SideEffectInterfaces.cpp b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
index f45d5f3d227407..ccbcb96eec58f5 100644
--- a/mlir/lib/Interfaces/SideEffectInterfaces.cpp
+++ b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
@@ -308,9 +308,22 @@ bool mlir::wouldOpBeTriviallyDead(Operation *op) {
bool mlir::hasOnlyReadEffect(Operation *op) {
if (auto memEffects = dyn_cast<MemoryEffectOpInterface>(op)) {
- return memEffects.onlyHasEffect<MemoryEffects::Read>();
+ if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
+ return memEffects.onlyHasEffect<MemoryEffects::Read>();
+ } else if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) {
+ // Otherwise, if the op does not implement the memory effect interface and
+ // it does not have recursive side effects, then it cannot be known that the
+ // op is moveable.
+ return false;
}
- return false;
+
+ // Recurse into the regions and ensure that all nested ops are memory effect
+ // free.
+ for (Region ®ion : op->getRegions())
+ for (Operation &op : region.getOps())
+ if (!hasOnlyReadEffect(&op))
+ return false;
+ return true;
}
bool mlir::isMemoryEffectFree(Operation *op) {
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index b0bf86c0c8e878..1d8fb568635b3f 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -139,7 +139,7 @@ size_t mlir::moveLoopInvariantCode(
!canBeHoisted(op, definedOutside))
continue;
// Can only hoist pure ops (side-effect free) when there is an op with
- // write side effects in the loop.
+ // write and/or unknown side effects in the loop.
if (!loopSideEffectFreeOrHasOnlyReadSideEffect && !isMemoryEffectFree(op))
continue;
>From 8a38077cfb16f5ea4c7ab14a82d3891ae67d5d74 Mon Sep 17 00:00:00 2001
From: ardau <ardau at meta.com>
Date: Mon, 6 Jan 2025 15:41:53 -0800
Subject: [PATCH 3/3] fixup! fixup! Enable LICM for ops with read side effects
in scf.for wrapped by a guard
---
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 3 +-
.../mlir/Interfaces/LoopLikeInterface.td | 20 +---
.../mlir/Interfaces/SideEffectInterfaces.h | 6 +-
.../Transforms/LoopInvariantCodeMotionUtils.h | 18 ++-
mlir/lib/Dialect/SCF/IR/SCF.cpp | 69 ++++-------
mlir/lib/Interfaces/SideEffectInterfaces.cpp | 30 ++---
.../Utils/LoopInvariantCodeMotionUtils.cpp | 102 ++++-------------
.../loop-invariant-code-motion.mlir | 107 +++++++++++++-----
8 files changed, 154 insertions(+), 201 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index b54df8e3ef313d..0bac8bfbbb7499 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -40,7 +40,7 @@ def SCF_Dialect : Dialect {
and then lowered to some final target like LLVM or SPIR-V.
}];
- let dependentDialects = ["arith::ArithDialect"];
+ let dependentDialects = ["arith::ArithDialect", "ub::UBDialect"];
}
// Base class for SCF dialect ops.
@@ -138,6 +138,7 @@ def ForOp : SCF_Op<"for",
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
"getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
"getLoopUpperBounds", "getYieldedValuesMutable",
+ "moveOutOfLoopWithGuard",
"promoteIfSingleIteration", "replaceWithAdditionalYields",
"wrapInTripCountCheck", "unwrapTripCountCheck",
"yieldTiledValuesAndReplace"]>,
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 831830130b0ddc..c4ff0b485565e0 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -80,23 +80,15 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
/*defaultImplementation=*/"op->moveBefore($_op);"
>,
InterfaceMethod<[{
- Wraps the loop into a trip-count check.
+ Moves the given loop-invariant operation out of the loop with a
+ trip-count guard.
}],
- /*retTy=*/"FailureOr<std::pair<::mlir::Operation *, ::mlir::Region *>>",
- /*methodName=*/"wrapInTripCountCheck",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/"return ::mlir::failure();"
- >,
- InterfaceMethod<[{
- Unwraps the trip-count check.
- }],
- /*retTy=*/"::llvm::LogicalResult",
- /*methodName=*/"unwrapTripCountCheck",
- /*args=*/(ins),
+ /*retTy=*/"void",
+ /*methodName=*/"moveOutOfLoopWithGuard",
+ /*args=*/(ins "::mlir::Operation *":$op),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return ::mlir::failure();
+ return;
}]
>,
InterfaceMethod<[{
diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
index 1a7f66e2234949..f8142567c60607 100644
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
@@ -433,9 +433,9 @@ bool wouldOpBeTriviallyDead(Operation *op);
/// conditions are satisfied.
bool isMemoryEffectFree(Operation *op);
-/// Returns true if the given operation implements `MemoryEffectOpInterface` and
-/// has only read effects.
-bool hasOnlyReadEffect(Operation *op);
+/// Returns true if the given operation is free of memory effects or has only
+/// read effect.
+bool isMemoryEffectFreeOrOnlyRead(Operation *op);
/// Returns the side effects of an operation. If the operation has
/// RecursiveMemoryEffects, include all side effects of child operations.
diff --git a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
index ae6719abe79c00..83a70abe0a5edf 100644
--- a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
@@ -47,8 +47,7 @@ class Value;
/// }
/// }
/// ```
-///
-/// Users must supply five callbacks.
+/// Users must supply four callbacks.
///
/// - `isDefinedOutsideRegion` returns true if the given value is invariant with
/// respect to the given region. A common implementation might be:
@@ -56,11 +55,11 @@ class Value;
/// - `shouldMoveOutOfRegion` returns true if the provided operation can be
///   moved out of the given region, e.g. if it is side-effect free or has only read
/// side effects.
-/// - `wrapInGuard` wraps the given operation in a trip-count check guard.
-/// - `moveOutOfRegion` moves the operation out of the given region. A common
-/// implementation might be: `op->moveBefore(region->getParentOp())`.
-/// - `unwrapGuard` unwraps the trip-count check if there is no op guarded by
-/// this check.
+/// - `moveOutOfRegionWithoutGuard` moves the operation out of the given region
+/// without a guard. A common implementation might be:
+/// `op->moveBefore(region->getParentOp())`.
+/// - `moveOutOfRegionWithGuard` moves the operation out of the given region
+/// with a guard.
///
/// An operation is moved if all of its operands satisfy
/// `isDefinedOutsideRegion` and it satisfies `shouldMoveOutOfRegion`.
@@ -70,9 +69,8 @@ size_t moveLoopInvariantCode(
ArrayRef<Region *> regions,
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
- function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
- function_ref<void(Operation *, Region *)> moveOutOfRegion,
- function_ref<LogicalResult()> unwrapGuard);
+ function_ref<void(Operation *)> moveOutOfRegionWithoutGuard,
+ function_ref<void(Operation *)> moveOutOfRegionWithGuard);
/// Move side-effect free loop invariant code out of a loop-like op using
/// methods provided by the interface.
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index d246ebdaaea2f3..0974cc214f99f2 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
@@ -395,58 +396,38 @@ std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
-FailureOr<std::pair<Operation *, Region *>> ForOp::wrapInTripCountCheck() {
-
+/// Moves the op out of the loop with a guard that checks if the loop has at
+/// least one iteration.
+void ForOp::moveOutOfLoopWithGuard(Operation *op) {
IRRewriter rewriter(this->getContext());
OpBuilder::InsertionGuard insertGuard(rewriter);
- rewriter.setInsertionPointAfter(this->getOperation());
-
- auto loc = this->getLoc();
- auto cmpIOp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
- this->getUpperBound(),
- this->getLowerBound());
- scf::YieldOp yieldInThen;
+ rewriter.setInsertionPoint(this->getOperation());
+ Location loc = this->getLoc();
+ arith::CmpIOp cmpIOp = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ult, this->getLowerBound(),
+ this->getUpperBound());
// Create the trip-count check.
- auto ifOp = rewriter.create<scf::IfOp>(
+ scf::YieldOp thenYield;
+ scf::IfOp ifOp = rewriter.create<scf::IfOp>(
loc, cmpIOp,
[&](OpBuilder &builder, Location loc) {
- yieldInThen = builder.create<scf::YieldOp>(loc, this->getResults());
+ thenYield = builder.create<scf::YieldOp>(loc, op->getResults());
},
[&](OpBuilder &builder, Location loc) {
- builder.create<scf::YieldOp>(loc, this->getInitArgs());
+ SmallVector<Value> poisonResults;
+ poisonResults.reserve(op->getResults().size());
+ for (Type type : op->getResults().getTypes()) {
+ ub::PoisonOp poisonOp =
+ rewriter.create<ub::PoisonOp>(loc, type, nullptr);
+ poisonResults.push_back(poisonOp);
+ }
+ builder.create<scf::YieldOp>(loc, poisonResults);
});
-
- for (auto [forOpResult, ifOpResult] :
- llvm::zip(this->getResults(), ifOp.getResults()))
- rewriter.replaceAllUsesExcept(forOpResult, ifOpResult, yieldInThen);
- // Move the scf.for into the then block.
- rewriter.moveOpBefore(this->getOperation(), yieldInThen);
- return std::make_pair(ifOp.getOperation(), &this->getRegion());
-}
-
-LogicalResult ForOp::unwrapTripCountCheck() {
- auto ifOp = (*this)->getParentRegion()->getParentOp();
- if (!isa<scf::IfOp>(ifOp))
- return failure();
-
- IRRewriter rewriter(ifOp->getContext());
- OpBuilder::InsertionGuard insertGuard(rewriter);
- rewriter.setInsertionPoint(ifOp);
-
- auto cmpOp = ifOp->getOperand(0).getDefiningOp();
- if (!isa<arith::CmpIOp>(cmpOp))
- return failure();
-
- auto wrappedForOp = this->getOperation();
- rewriter.moveOpBefore(wrappedForOp, ifOp);
-
- for (auto [forOpResult, ifOpResult] :
- llvm::zip(wrappedForOp->getResults(), ifOp->getResults()))
- rewriter.replaceAllUsesWith(ifOpResult, forOpResult);
-
- rewriter.eraseOp(ifOp);
- rewriter.eraseOp(cmpOp);
- return success();
+ for (auto [opResult, ifOpResult] :
+ llvm::zip(op->getResults(), ifOp->getResults()))
+ rewriter.replaceAllUsesExcept(opResult, ifOpResult, thenYield);
+ // Move the op into the then block.
+ rewriter.moveOpBefore(op, thenYield);
}
/// Promotes the loop body of a forOp to its containing block if the forOp
diff --git a/mlir/lib/Interfaces/SideEffectInterfaces.cpp b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
index ccbcb96eec58f5..9d681ba59b8b28 100644
--- a/mlir/lib/Interfaces/SideEffectInterfaces.cpp
+++ b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
@@ -306,26 +306,6 @@ bool mlir::wouldOpBeTriviallyDead(Operation *op) {
return wouldOpBeTriviallyDeadImpl(op);
}
-bool mlir::hasOnlyReadEffect(Operation *op) {
- if (auto memEffects = dyn_cast<MemoryEffectOpInterface>(op)) {
- if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
- return memEffects.onlyHasEffect<MemoryEffects::Read>();
- } else if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) {
- // Otherwise, if the op does not implement the memory effect interface and
- // it does not have recursive side effects, then it cannot be known that the
- // op is moveable.
- return false;
- }
-
- // Recurse into the regions and ensure that all nested ops are memory effect
- // free.
- for (Region &region : op->getRegions())
- for (Operation &op : region.getOps())
- if (!hasOnlyReadEffect(&op))
- return false;
- return true;
-}
-
bool mlir::isMemoryEffectFree(Operation *op) {
if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
if (!memInterface.hasNoEffect())
@@ -383,6 +363,16 @@ mlir::getEffectsRecursively(Operation *rootOp) {
return effects;
}
+bool mlir::isMemoryEffectFreeOrOnlyRead(Operation *op) {
+ std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
+ getEffectsRecursively(op);
+ if (!effects)
+ return false;
+ return llvm::all_of(*effects, [](const auto &effect) {
+ return isa<MemoryEffects::Read>(effect.getEffect());
+ });
+}
+
bool mlir::isSpeculatable(Operation *op) {
auto conditionallySpeculatable = dyn_cast<ConditionallySpeculatable>(op);
if (!conditionallySpeculatable)
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 1d8fb568635b3f..094db01176a28c 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -56,86 +56,41 @@ static bool canBeHoisted(Operation *op,
op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
}
-static bool dependsOnGuarded(Operation *op,
- function_ref<bool(OpOperand &)> condition) {
- auto walkFn = [&](Operation *child) {
- for (OpOperand &operand : child->getOpOperands()) {
- if (!condition(operand))
- return WalkResult::interrupt();
- }
- return WalkResult::advance();
- };
- return op->walk(walkFn).wasInterrupted();
-}
-
-static bool dependsOnGuarded(Operation *op,
- function_ref<bool(Value)> definedOutsideGuard) {
- return dependsOnGuarded(op, [&](OpOperand &operand) {
- return definedOutsideGuard(operand.get());
- });
-}
-
-static bool loopSideEffectFreeOrHasOnlyReadEffect(Operation *loop) {
- for (Region &region : loop->getRegions())
- for (Block &block : region.getBlocks()) {
- for (Operation &op : block.getOperations()) {
- if (!isMemoryEffectFree(&op) && !hasOnlyReadEffect(&op))
- return false;
- }
- }
- }
- return true;
-}
-
size_t mlir::moveLoopInvariantCode(
ArrayRef<Region *> regions,
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
- function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
- function_ref<void(Operation *, Region *)> moveOutOfRegion,
- function_ref<LogicalResult()> unwrapGuard) {
+ function_ref<void(Operation *)> moveOutOfRegionWithoutGuard,
+ function_ref<void(Operation *)> moveOutOfRegionWithGuard) {
size_t numMoved = 0;
for (Region *region : regions) {
LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
<< *region->getParentOp() << "\n");
- auto loopSideEffectFreeOrHasOnlyReadSideEffect =
- loopSideEffectFreeOrHasOnlyReadEffect(region->getParentOp());
-
- size_t numMovedWithoutGuard = 0;
-
- FailureOr<std::pair<Operation *, Region *>> ifOpAndRegion = wrapInGuard();
- Region *loopRegion = region;
- auto isLoopWrapped = false;
- if (succeeded(ifOpAndRegion)) {
- loopRegion = ifOpAndRegion->second;
- isLoopWrapped = true;
- }
+ bool anyOpHoistedWithGuard = false;
+ bool loopSideEffectFreeOrHasOnlyReadSideEffect =
+ isMemoryEffectFreeOrOnlyRead(region->getParentOp());
std::queue<Operation *> worklist;
// Add top-level operations in the loop body to the worklist.
- for (Operation &op : loopRegion->getOps())
+ for (Operation &op : region->getOps())
worklist.push(&op);
auto definedOutside = [&](Value value) {
- return isDefinedOutsideRegion(value, loopRegion);
- };
-
- auto definedOutsideGuard = [&](Value value) {
- return isDefinedOutsideRegion(value, loopRegion->getParentRegion());
+ return isDefinedOutsideRegion(value, region);
};
while (!worklist.empty()) {
Operation *op = worklist.front();
worklist.pop();
// Skip ops that have already been moved. Check if the op can be hoisted.
- if (op->getParentRegion() != loopRegion)
+ if (op->getParentRegion() != region)
continue;
LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
- if (!shouldMoveOutOfRegion(op, loopRegion) ||
+ if (!shouldMoveOutOfRegion(op, region) ||
!canBeHoisted(op, definedOutside))
continue;
// Can only hoist pure ops (side-effect free) when there is an op with
@@ -143,30 +98,25 @@ size_t mlir::moveLoopInvariantCode(
if (!loopSideEffectFreeOrHasOnlyReadSideEffect && !isMemoryEffectFree(op))
continue;
- LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
-
- auto moveWithoutGuard = isMemoryEffectFree(op) &&
- !dependsOnGuarded(op, definedOutsideGuard) &&
- isLoopWrapped;
- numMovedWithoutGuard += moveWithoutGuard;
-
- moveOutOfRegion(op, moveWithoutGuard ? loopRegion->getParentRegion()
- : loopRegion);
+ bool moveWithoutGuard = !anyOpHoistedWithGuard && isMemoryEffectFree(op);
+ if (moveWithoutGuard) {
+ LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op
+ << " without guard\n");
+ moveOutOfRegionWithoutGuard(op);
+ } else {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Moving loop-invariant op: " << *op << " with guard\n");
+ moveOutOfRegionWithGuard(op);
+ anyOpHoistedWithGuard = true;
+ }
++numMoved;
// Since the op has been moved, we need to check its users within the
// top-level of the loop body.
for (Operation *user : op->getUsers())
- if (user->getParentRegion() == loopRegion)
+ if (user->getParentRegion() == region)
worklist.push(user);
}
-
- // Unwrap the loop if it was wrapped but no ops were moved in the guard.
- if (isLoopWrapped && numMovedWithoutGuard == numMoved) {
- auto tripCountCheckUnwrapped = unwrapGuard();
- if (failed(tripCountCheckUnwrapped))
- llvm_unreachable("Should not fail unwrapping trip-count check");
- }
}
return numMoved;
@@ -179,14 +129,10 @@ size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
return !region->isAncestor(value.getParentRegion());
},
[&](Operation *op, Region *) {
- return isSpeculatable(op) &&
- (isMemoryEffectFree(op) || hasOnlyReadEffect(op));
- },
- [&]() { return loopLike.wrapInTripCountCheck(); },
- [&](Operation *op, Region *region) {
- op->moveBefore(region->getParentOp());
+ return isSpeculatable(op) && isMemoryEffectFreeOrOnlyRead(op);
},
- [&]() { return loopLike.unwrapTripCountCheck(); });
+ [&](Operation *op) { loopLike.moveOutOfLoop(op); },
+ [&](Operation *op) { loopLike.moveOutOfLoopWithGuard(op); });
}
namespace {
diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir
index f2c4f09e1e7b22..6a1ca519dec1dd 100644
--- a/mlir/test/Transforms/loop-invariant-code-motion.mlir
+++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir
@@ -595,48 +595,93 @@ func.func @test_recursively_speculatable_op_failure(%lb: index, %ub: index, %ste
// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success
func.func @test_speculatable_op_with_read_side_effect_success(%lb: index, %ub: index, %step: index) -> i32 {
- // CHECK: arith.cmpi
- // CHECK-NEXT: test.always_speculatable_op
- // CHECK-NEXT: scf.if
+ // CHECK: test.always_speculatable_op
+ // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LB:.*]], %[[UB:.*]] : index
+ // CHECK-NEXT: scf.if %[[CMP]]
// CHECK-NEXT: test.speculatable_op_with_memread
- // CHECK-NEXT: scf.for
+ // CHECK: else
+ // CHECK-NEXT: ub.poison : i32
+ // CHECK: scf.for %[[_:.*]] = %[[LB]] to %[[UB]]
+ // CHECK-NOT: test.always_speculatable_op
+ // CHECK-NOT: test.speculatable_op_with_memread
+ %cst_0 = arith.constant 0 : i32
+ %cst_42 = arith.constant dense<42> : tensor<64xi32>
+ %ind_42 = arith.constant 42 : index
+ %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+ %always_speculate = "test.always_speculatable_op"() : () -> i32
+ %only_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+ %i_cast = arith.index_cast %i: index to i32
+ %add = arith.addi %acc, %i_cast : i32
+ %sum = arith.addi %add, %only_read : i32
+ scf.yield %sum : i32
+ }
+ return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_multiple_result_success
+func.func @test_speculatable_op_with_read_side_effect_multiple_result_success(%lb: index, %ub: index, %step: index) -> i32 {
+ // CHECK: test.always_speculatable_op
+ // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LB:.*]], %[[UB:.*]] : index
+ // CHECK-NEXT: scf.if %[[CMP]]
+ // CHECK-NEXT: test.speculatable_op_with_memread
+ // CHECK: else
+ // CHECK-NEXT: ub.poison : i32
+ // CHECK-NEXT: ub.poison : f32
+ // CHECK: scf.for %[[_:.*]] = %[[LB]] to %[[UB]]
// CHECK-NOT: test.always_speculatable_op
// CHECK-NOT: test.speculatable_op_with_memread
%cst_0 = arith.constant 0 : i32
%cst_42 = arith.constant dense<42> : tensor<64xi32>
%ind_42 = arith.constant 42 : index
%sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
- %always = "test.always_speculatable_op"() : () -> i32
- %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+ %always_speculate = "test.always_speculatable_op"() : () -> i32
+ %only_read:2 = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> (i32, f32)
%i_cast = arith.index_cast %i: index to i32
- %i_sum = arith.addi %acc, %i_cast : i32
- %test_sum = arith.addi %i_sum, %always_read : i32
- scf.yield %test_sum : i32
+ %add = arith.addi %acc, %i_cast : i32
+ %sum = arith.addi %add, %only_read#0 : i32
+ scf.yield %sum : i32
}
return %sum_result : i32
}
// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success_with_dependents
func.func @test_speculatable_op_with_read_side_effect_success_with_dependents(%lb: index, %ub: index, %step: index) -> i32 {
- // CHECK: arith.cmpi
- // CHECK-NEXT: test.always_speculatable_op
- // CHECK-NEXT: scf.if
+ // CHECK: %[[ALWAYS:.*]] = "test.always_speculatable_op"
+ // CHECK-NEXT: %[[CMP0:.*]] = arith.cmpi ult, %[[LB:.*]], %[[UB:.*]] : index
+ // CHECK-NEXT: %[[IF0:.*]] = scf.if %[[CMP0]]
// CHECK-NEXT: test.speculatable_op_with_memread
- // CHECK-NEXT: arith.addi
- // CHECK-NEXT: scf.for
+ // CHECK: else
+ // CHECK-NEXT: ub.poison : i32
+ // CHECK: %[[CMP1:.*]] = arith.cmpi ult, %[[LB]], %[[UB]] : index
+ // CHECK-NEXT: %[[IF1:.*]] = scf.if %[[CMP1]]
+ // CHECK-NEXT: arith.addi %[[ALWAYS]], %[[IF0]]
+ // CHECK: else
+ // CHECK-NEXT: ub.poison : i32
+ // CHECK: %[[CMP2:.*]] = arith.cmpi ult, %[[LB]], %[[UB]] : index
+ // CHECK-NEXT: %[[IF2:.*]] = scf.if %[[CMP2]]
+ // CHECK-NEXT: test.speculatable_op_with_memread
+ // CHECK: else
+ // CHECK-NEXT: ub.poison : i32
+ // CHECK: %[[CMP3:.*]] = arith.cmpi ult, %[[LB]], %[[UB]] : index
+ // CHECK-NEXT: %{{.*}} = scf.if %[[CMP3]]
+ // CHECK-NEXT: arith.addi %[[IF1]], %[[IF2]]
+ // CHECK: else
+ // CHECK-NEXT: ub.poison : i32
+ // CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]]
// CHECK-NOT: test.always_speculatable_op
// CHECK-NOT: test.speculatable_op_with_memread
%cst_0 = arith.constant 0 : i32
%cst_42 = arith.constant dense<42> : tensor<64xi32>
%ind_42 = arith.constant 42 : index
%sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
- %always = "test.always_speculatable_op"() : () -> i32
- %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
- %add = arith.addi %always_read, %cst_0 : i32
+ %always_speculate = "test.always_speculatable_op"() : () -> i32
+ %only_read_0 = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+ %add_0 = arith.addi %always_speculate, %only_read_0 : i32
+ %only_read_1 = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+ %add_1 = arith.addi %add_0, %only_read_1 : i32
%i_cast = arith.index_cast %i: index to i32
- %i_sum = arith.addi %acc, %i_cast : i32
- %test_sum = arith.addi %i_sum, %add : i32
- scf.yield %test_sum : i32
+ %sum = arith.addi %add_1, %i_cast : i32
+ scf.yield %sum : i32
}
return %sum_result : i32
}
@@ -652,13 +697,13 @@ func.func @test_speculatable_op_with_read_side_effect_failure_due_to_write(%lb:
%cst_42 = arith.constant dense<42> : tensor<64xi32>
%ind_42 = arith.constant 42 : index
%sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
- %always = "test.always_speculatable_op"() : () -> i32
- %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+ %always_speculate = "test.always_speculatable_op"() : () -> i32
+ %only_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
%i_cast = arith.index_cast %i: index to i32
- %i_sum = arith.addi %acc, %i_cast : i32
- %test_sum = arith.addi %i_sum, %always_read : i32
- %always_write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
- scf.yield %test_sum : i32
+ %add = arith.addi %acc, %i_cast : i32
+ %sum = arith.addi %add, %only_read : i32
+ %write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
+ scf.yield %sum : i32
}
return %sum_result : i32
}
@@ -676,18 +721,18 @@ func.func @test_speculatable_op_with_read_side_effect_failure_due_to_nested_writ
%cst_42 = arith.constant dense<42> : tensor<64xi32>
%ind_42 = arith.constant 42 : index
%sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
- %always = "test.always_speculatable_op"() : () -> i32
- %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+ %always_speculate = "test.always_speculatable_op"() : () -> i32
+ %only_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
%i_cast = arith.index_cast %i: index to i32
- %i_sum = arith.addi %acc, %i_cast : i32
- %test_sum = arith.addi %i_sum, %always_read : i32
+ %add = arith.addi %acc, %i_cast : i32
+ %sum = arith.addi %add, %only_read : i32
scf.for %j = %lb to %ub step %step {
%eq42 = arith.cmpi eq, %j, %ind_42 : index
scf.if %eq42 {
%always_write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
}
}
- scf.yield %test_sum : i32
+ scf.yield %sum : i32
}
return %sum_result : i32
}
More information about the Mlir-commits
mailing list