[Mlir-commits] [mlir] [llvm] [mlir][Transforms] Add loop-invariant subset hoisting (LISH) transformation (PR #70619)
Matthias Springer
llvmlistbot at llvm.org
Tue Oct 31 18:57:05 PDT 2023
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/70619
>From ac0ae73c1bd0e401e7952d0c40015b6739e565ef Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Wed, 1 Nov 2023 10:48:25 +0900
Subject: [PATCH] [mlir] Loop-invariant subset hoisting
---
.../Transforms/LoopInvariantCodeMotionUtils.h | 39 +++
mlir/include/mlir/Transforms/Passes.h | 3 +
mlir/include/mlir/Transforms/Passes.td | 5 +
.../Transforms/LoopInvariantCodeMotion.cpp | 20 ++
mlir/lib/Transforms/Utils/CMakeLists.txt | 1 +
.../Utils/LoopInvariantCodeMotionUtils.cpp | 254 +++++++++++++++++-
.../loop-invariant-subset-hoisting.mlir | 237 ++++++++++++++++
.../llvm-project-overlay/mlir/BUILD.bazel | 1 +
8 files changed, 556 insertions(+), 4 deletions(-)
create mode 100644 mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
diff --git a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
index c7b816eb28faf5f..579054070f729b0 100644
--- a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
@@ -71,6 +71,45 @@ size_t moveLoopInvariantCode(
/// methods provided by the interface.
size_t moveLoopInvariantCode(LoopLikeOpInterface loopLike);
+/// Hoist loop-invariant tensor subsets (subset extraction and subset insertion
+/// ops) from loop-like ops. Extraction ops are moved before the loop. Insertion
+/// ops are moved after the loop. The loop body operates on newly added region
+/// iter_args (one per extraction-insertion pair).
+///
+/// A subset extraction op (`SubsetExtractionOpInterface`) extracts from a
+/// tensor value at a subset. The result of the op may have an arbitrary type,
+/// i.e., not necessarily a tensor type. Example: "tensor.extract_slice".
+///
+/// A subset insertion op (`SubsetInsertionOpInterface`) inserts into a tensor
+/// value ("destination") at a subset. Example: "tensor.insert_slice".
+///
+/// Matching extraction-insertion subset ops can be hoisted from a loop if there
+/// are no other ops within the loop that operate on the same or on an
+/// overlapping subset. In particular, non-subset ops can prevent hoisting
+/// because the analysis does not know what subset they operate on.
+///
+/// Example:
+/// ```
+/// %r = scf.for ... iter_args(%t = %a) -> (tensor<?xf32>) {
+/// %0 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+/// %1 = "test.foo"(%0) : (tensor<5xf32>) -> (tensor<5xf32>)
+/// %2 = tensor.insert_slice %1 into %t[0][5][1]
+/// : tensor<5xf32> into tensor<?xf32>
+/// scf.yield %2 : tensor<?xf32>
+/// }
+/// ```
+/// Is rewritten to:
+/// ```
+/// %0 = tensor.extract_slice %a[0][5][1] : tensor<?xf32> to tensor<5xf32>
+/// %new_loop:2 = scf.for ... iter_args(%t = %a, %h = %0) -> (tensor<?xf32>, tensor<5xf32>) {
+/// %1 = "test.foo"(%h) : (tensor<5xf32>) -> (tensor<5xf32>)
+/// scf.yield %t, %1 : tensor<?xf32>, tensor<5xf32>
+/// }
+/// %r = tensor.insert_slice %new_loop#1 into %new_loop#0
+/// : tensor<5xf32> into tensor<?xf32>
+/// ```
+LoopLikeOpInterface hoistLoopInvariantSubsets(LoopLikeOpInterface loopLike);
+
} // end namespace mlir
#endif // MLIR_TRANSFORMS_LOOPINVARIANTCODEMOTIONUTILS_H
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 320932bb999561f..11f5b23e62c663b 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -78,6 +78,9 @@ std::unique_ptr<Pass> createGenerateRuntimeVerificationPass();
/// instructions out of the loop.
std::unique_ptr<Pass> createLoopInvariantCodeMotionPass();
+/// Creates a pass that hoists loop-invariant subset ops.
+std::unique_ptr<Pass> createLoopInvariantSubsetHoistingPass();
+
/// Creates a pass to strip debug information from a function.
std::unique_ptr<Pass> createStripDebugInfoPass();
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 26d2ff3c30ded57..2d2d54fb8fb5eaa 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -329,6 +329,11 @@ def LoopInvariantCodeMotion : Pass<"loop-invariant-code-motion"> {
let constructor = "mlir::createLoopInvariantCodeMotionPass()";
}
+def LoopInvariantSubsetHoisting : Pass<"loop-invariant-subset-hoisting"> {
+ let summary = "Hoist loop invariant subset ops outside of the loop";
+ let constructor = "mlir::createLoopInvariantSubsetHoistingPass()";
+}
+
def Mem2Reg : Pass<"mem2reg"> {
let summary = "Promotes memory slots into values.";
let description = [{
diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
index 854fde09bac796e..e6d8af8f05832d3 100644
--- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
+++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
@@ -18,6 +18,7 @@
namespace mlir {
#define GEN_PASS_DEF_LOOPINVARIANTCODEMOTION
+#define GEN_PASS_DEF_LOOPINVARIANTSUBSETHOISTING
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir
@@ -29,6 +30,12 @@ struct LoopInvariantCodeMotion
: public impl::LoopInvariantCodeMotionBase<LoopInvariantCodeMotion> {
void runOnOperation() override;
};
+
+struct LoopInvariantSubsetHoisting
+ : public impl::LoopInvariantSubsetHoistingBase<
+ LoopInvariantSubsetHoisting> {
+ void runOnOperation() override;
+};
} // namespace
void LoopInvariantCodeMotion::runOnOperation() {
@@ -39,6 +46,19 @@ void LoopInvariantCodeMotion::runOnOperation() {
[&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
}
+void LoopInvariantSubsetHoisting::runOnOperation() {
+ // Walk through all loops in a function in innermost-loop-first order. This
+ // way, we first hoist from the inner loop, and place the ops in the outer
+ // loop, which in turn can be further hoisted from.
+ getOperation()->walk([&](LoopLikeOpInterface loopLike) {
+ (void)hoistLoopInvariantSubsets(loopLike);
+ });
+}
+
std::unique_ptr<Pass> mlir::createLoopInvariantCodeMotionPass() {
return std::make_unique<LoopInvariantCodeMotion>();
}
+
+std::unique_ptr<Pass> mlir::createLoopInvariantSubsetHoistingPass() {
+ return std::make_unique<LoopInvariantSubsetHoisting>();
+}
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index efc7a5160b2399e..1c608e0634a67e2 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -20,5 +20,6 @@ add_mlir_library(MLIRTransformUtils
MLIRFunctionInterfaces
MLIRLoopLikeInterface
MLIRSideEffectInterfaces
+ MLIRSubsetOpInterface
MLIRRewrite
)
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 080492da6ae4b97..01318cf7328b543 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -11,9 +11,12 @@
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
+
#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
#include "llvm/Support/Debug.h"
#include <queue>
@@ -26,7 +29,7 @@ using namespace mlir;
/// loop (by means of calling definedOutside).
/// - the op has no side-effects.
static bool canBeHoisted(Operation *op,
- function_ref<bool(Value)> definedOutside) {
+ function_ref<bool(OpOperand &)> condition) {
// Do not move terminators.
if (op->hasTrait<OpTrait::IsTerminator>())
return false;
@@ -35,11 +38,11 @@ static bool canBeHoisted(Operation *op,
// defined outside of the loop or in a nested region, but not at the level of
// the loop body.
auto walkFn = [&](Operation *child) {
- for (Value operand : child->getOperands()) {
+ for (OpOperand &operand : child->getOpOperands()) {
// Ignore values defined in a nested region.
- if (op->isAncestor(operand.getParentRegion()->getParentOp()))
+ if (op->isAncestor(operand.get().getParentRegion()->getParentOp()))
continue;
- if (!definedOutside(operand))
+ if (!condition(operand))
return WalkResult::interrupt();
}
return WalkResult::advance();
@@ -47,6 +50,12 @@ static bool canBeHoisted(Operation *op,
return !op->walk(walkFn).wasInterrupted();
}
+static bool canBeHoisted(Operation *op,
+ function_ref<bool(Value)> definedOutside) {
+ return canBeHoisted(
+ op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
+}
+
size_t mlir::moveLoopInvariantCode(
ArrayRef<Region *> regions,
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
@@ -105,3 +114,240 @@ size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
},
[&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
}
+
+namespace {
+/// Helper data structure that keeps track of equivalent/disjoint subset ops.
+class MatchingSubsets {
+public:
+ /// Insert a subset op.
+ void insert(SubsetOpInterface op) {
+ allSubsetOps.push_back(op);
+ if (auto extractionOp =
+ dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
+ insertExtractionOp(extractionOp);
+ if (auto insertionOp =
+ dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
+ insertInsertionOp(insertionOp);
+ }
+
+ /// Return a range of matching extraction-insertion subset ops. If there is no
+ /// matching extraction/insertion op, the respective value is empty. Ops are
+ /// skipped if there are other subset ops that are not guaranteed to operate
+ /// on disjoint subsets.
+ auto getHoistableSubsetOps() {
+ return llvm::make_filter_range(
+ llvm::zip(extractions, insertions), [&](auto pair) {
+ auto [extractionOp, insertionOp] = pair;
+ // Hoist only if the extracted and inserted values have the same type.
+ if (extractionOp && insertionOp &&
+ extractionOp->getResult(0).getType() !=
+ insertionOp.getSourceOperand().get().getType())
+ return false;
+ // Hoist only if there are no conflicting subset ops.
+ return allDisjoint(extractionOp, insertionOp);
+ });
+ }
+
+private:
+ /// Helper function for equivalence of tensor values. Since only insertion
+ /// subset ops (that are also destination style ops) are followed when
+ /// traversing the SSA use-def chain, all tensor values are equivalent.
+ static bool isEquivalent(Value v1, Value v2) { return true; }
+
+ /// Return "true" if the subsets of the given extraction and insertion ops
+ /// are operating disjoint from the subsets that all other known subset ops
+ /// are operating on.
+ bool allDisjoint(SubsetExtractionOpInterface extractionOp,
+ SubsetInsertionOpInterface insertionOp) const {
+ for (SubsetOpInterface other : allSubsetOps) {
+ if (other == extractionOp || other == insertionOp)
+ continue;
+ if (extractionOp &&
+ !other.operatesOnDisjointSubset(extractionOp, isEquivalent))
+ return false;
+ if (insertionOp &&
+ !other.operatesOnDisjointSubset(insertionOp, isEquivalent))
+ return false;
+ }
+ return true;
+ }
+
+ /// Insert a subset extraction op. If the subset is equivalent to an existing
+ /// subset insertion op, pair them up. (If there is already a paired up subset
+ /// extraction op, overwrite the subset extraction op.)
+ void insertExtractionOp(SubsetExtractionOpInterface extractionOp) {
+ for (auto it : llvm::enumerate(insertions)) {
+ if (!it.value())
+ continue;
+ auto other = cast<SubsetOpInterface>(it.value().getOperation());
+ if (other.operatesOnEquivalentSubset(extractionOp, isEquivalent)) {
+ extractions[it.index()] = extractionOp;
+ return;
+ }
+ }
+ // There is no known equivalent insertion op. Create a new entry.
+ extractions.push_back(extractionOp);
+ insertions.push_back({});
+ }
+
+ /// Insert a subset insertion op. If the subset is equivalent to an existing
+ /// subset extraction op, pair them up. (If there is already a paired up
+ /// subset insertion op, overwrite the subset insertion op.)
+ void insertInsertionOp(SubsetInsertionOpInterface insertionOp) {
+ for (auto it : llvm::enumerate(extractions)) {
+ if (!it.value())
+ continue;
+ auto other = cast<SubsetOpInterface>(it.value().getOperation());
+ if (other.operatesOnEquivalentSubset(insertionOp, isEquivalent)) {
+ insertions[it.index()] = insertionOp;
+ return;
+ }
+ }
+ // There is no known equivalent extraction op. Create a new entry.
+ extractions.push_back({});
+ insertions.push_back(insertionOp);
+ }
+
+ SmallVector<SubsetExtractionOpInterface> extractions;
+ SmallVector<SubsetInsertionOpInterface> insertions;
+ SmallVector<SubsetOpInterface> allSubsetOps;
+};
+} // namespace
+
+/// If the given value has a single use by an op that is a terminator, return
+/// that use. Otherwise, return nullptr.
+static OpOperand *getSingleTerminatorUse(Value value) {
+ if (!value.hasOneUse())
+ return nullptr;
+ OpOperand &use = *value.getUses().begin();
+ if (use.getOwner()->hasTrait<OpTrait::IsTerminator>())
+ return &use;
+ return nullptr;
+}
+
+/// Hoist all subset ops that operate on the idx-th region iter_arg of the given
+/// loop-like op and index into loop-invariant subset locations. Return the
+/// newly created loop op (that has extra iter_args) or the original loop op if
+/// nothing was hoisted.
+static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
+ BlockArgument iterArg) {
+ IRRewriter rewriter(loopLike.getContext());
+ assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
+ auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
+ int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
+ Value value = iterArg;
+ MatchingSubsets subsets;
+
+ // Traverse use-def chain. Subset ops can be hoisted only if all ops along the
+ // use-def chain starting from the region iter_arg are subset extraction or
+ // subset insertion ops. The chain must terminate at the corresponding yield
+ // operand (e.g., no swapping of iter_args).
+ OpOperand *yieldedOperand = nullptr;
+ // Iterate until the single use of the current SSA value is a terminator,
+ // which is expected to be the yielding operation of the loop.
+ while (!(yieldedOperand = getSingleTerminatorUse(value))) {
+ Value nextValue = {};
+
+ for (OpOperand &use : value.getUses()) {
+ auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner());
+ if (!subsetOp)
+ return loopLike;
+ subsets.insert(subsetOp);
+
+ if (auto insertionOp =
+ dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
+ // The value must be used as a destination. (In case of a source, the
+ // entire tensor would be read, which would prevent any hoisting.)
+ if (&use != &insertionOp.getDestinationOperand())
+ return loopLike;
+ // There must be a single use-def chain from the region iter_arg to the
+ // terminator. I.e., only one insertion op. Branches are not supported.
+ if (nextValue)
+ return loopLike;
+ nextValue = insertionOp.getUpdatedDestination();
+ }
+ }
+
+ // Nothing can be hoisted if the chain does not continue with loop yielding
+ // op or a subset insertion op.
+ if (!nextValue)
+ return loopLike;
+ value = nextValue;
+ }
+
+ // Hoist only if the SSA use-def chain ends in the yielding terminator of the
+/// loop and the yielded value is the `iterArgIdx`-th operand. (I.e., there is no
+ // swapping yield.)
+ if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand)
+ return loopLike;
+
+ // Hoist all matching extraction-insertion pairs one-by-one.
+ for (auto it : subsets.getHoistableSubsetOps()) {
+ auto extractionOp = std::get<0>(it);
+ auto insertionOp = std::get<1>(it);
+
+ // Ops cannot be hoisted if they depend on loop-variant values.
+ if (extractionOp) {
+ if (!canBeHoisted(extractionOp, [&](OpOperand &operand) {
+ return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
+ &operand == &extractionOp.getSourceOperand();
+ }))
+ extractionOp = {};
+ }
+ if (insertionOp) {
+ if (!canBeHoisted(insertionOp, [&](OpOperand &operand) {
+ return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
+ &operand == &insertionOp.getSourceOperand() ||
+ &operand == &insertionOp.getDestinationOperand();
+ }))
+ insertionOp = {};
+ }
+
+ // Only hoist extraction-insertion pairs for now. Standalone extractions/
+ // insertions that are loop-invariant could be hoisted, but there may be
+ // easier ways to canonicalize the IR.
+ if (extractionOp && insertionOp) {
+ // Create a new loop with an additional iter_arg.
+ NewYieldValuesFn newYieldValuesFn =
+ [&](OpBuilder &b, Location loc,
+ ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
+ return {insertionOp.getSourceOperand().get()};
+ };
+ FailureOr<LoopLikeOpInterface> newLoop =
+ loopLike.replaceWithAdditionalYields(
+ rewriter, extractionOp.getResult(),
+ /*replaceInitOperandUsesInLoop=*/true, newYieldValuesFn);
+ if (failed(newLoop))
+ return loopLike;
+ loopLike = *newLoop;
+
+ // Hoist the extraction/insertion ops.
+ iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
+ OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
+ OpResult newLoopResult = loopLike.getLoopResults()->back();
+ extractionOp->moveBefore(loopLike);
+ insertionOp->moveAfter(loopLike);
+ insertionOp.getUpdatedDestination().replaceAllUsesWith(
+ insertionOp.getDestinationOperand().get());
+ extractionOp.getSourceOperand().set(
+ loopLike.getTiedLoopInit(iterArg)->get());
+ loopResult.replaceAllUsesWith(insertionOp.getUpdatedDestination());
+ insertionOp.getSourceOperand().set(newLoopResult);
+ insertionOp.getDestinationOperand().set(loopResult);
+ }
+ }
+
+ return loopLike;
+}
+
+LoopLikeOpInterface
+mlir::hoistLoopInvariantSubsets(LoopLikeOpInterface loopLike) {
+ // Note: As subset ops are getting hoisted, the number of region iter_args
+ // increases. This can enable further hoisting opportunities on the new
+ // iter_args.
+ for (int64_t i = 0;
+ i < static_cast<int64_t>(loopLike.getRegionIterArgs().size()); ++i) {
+ loopLike = hoistSubsetAtIterArg(loopLike, loopLike.getRegionIterArgs()[i]);
+ }
+ return loopLike;
+}
diff --git a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
new file mode 100644
index 000000000000000..5cded4c99182c14
--- /dev/null
+++ b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
@@ -0,0 +1,237 @@
+// RUN: mlir-opt %s -split-input-file -loop-invariant-subset-hoisting | FileCheck %s
+
+// CHECK-LABEL: func @hoist_matching_extract_insert(
+// CHECK-SAME: %[[arg:.*]]: tensor<?xf32>
+func.func @hoist_matching_extract_insert(%arg: tensor<?xf32>) -> tensor<?xf32> {
+ %lb = "test.foo"() : () -> (index)
+ %ub = "test.foo"() : () -> (index)
+ %step = "test.foo"() : () -> (index)
+
+ // CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]]
+ // CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]])
+ %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+ // CHECK: tensor.extract_slice %[[t]][9] [5] [1]
+ %standalone = tensor.extract_slice %t[9][5][1] : tensor<?xf32> to tensor<5xf32>
+ "test.foo"(%standalone) : (tensor<5xf32>) -> ()
+
+ %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+ // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
+ %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+ %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+ // CHECK: scf.yield %[[t]], %[[foo]]
+ scf.yield %3 : tensor<?xf32>
+ }
+ // CHECK: %[[insert:.*]] = tensor.insert_slice %[[for]]#1 into %[[for]]#0
+
+ // CHECK: return %[[insert]]
+ return %0 : tensor<?xf32>
+}
+
+// -----
+// CHECK-LABEL: func @subset_of_subset(
+// CHECK-SAME: %[[arg:.*]]: tensor<?xf32>
+func.func @subset_of_subset(%arg: tensor<?xf32>) -> tensor<?xf32> {
+ %lb = "test.foo"() : () -> (index)
+ %ub = "test.foo"() : () -> (index)
+ %step = "test.foo"() : () -> (index)
+ // CHECK: %[[extract1:.*]] = tensor.extract_slice %[[arg]]
+ // CHECK: %[[extract2:.*]] = tensor.extract_slice %[[extract1]]
+ // CHECK: %[[for:.*]]:3 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted1:.*]] = %[[extract1]], %[[hoisted2:.*]] = %[[extract2]])
+ %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+ %extract1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+ %extract2 = tensor.extract_slice %extract1[1][2][1] : tensor<5xf32> to tensor<2xf32>
+
+ // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted2]])
+ %2 = "test.foo"(%extract2) : (tensor<2xf32>) -> (tensor<2xf32>)
+
+ %insert1 = tensor.insert_slice %2 into %extract1[1][2][1] : tensor<2xf32> into tensor<5xf32>
+ %insert2 = tensor.insert_slice %insert1 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+
+ // CHECK: scf.yield %[[t]], %[[hoisted1]], %[[foo]]
+ scf.yield %insert2 : tensor<?xf32>
+ }
+ // CHECK: %[[insert2:.*]] = tensor.insert_slice %[[for]]#2 into %[[for]]#1[1] [2] [1]
+ // CHECK: %[[insert1:.*]] = tensor.insert_slice %[[insert2]] into %[[for]]#0[0] [5] [1]
+
+ // CHECK: return %[[insert1]]
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @hoist_matching_chain(
+// CHECK-SAME: %[[arg:.*]]: tensor<?xf32>
+func.func @hoist_matching_chain(%arg: tensor<?xf32>) -> tensor<?xf32> {
+ %lb = "test.foo"() : () -> (index)
+ %ub = "test.foo"() : () -> (index)
+ %step = "test.foo"() : () -> (index)
+ %sz = "test.foo"() : () -> (index)
+
+ // CHECK: %[[extract2:.*]] = tensor.extract_slice %[[arg]][%{{.*}}] [5] [1]
+ // CHECK: %[[extract1:.*]] = tensor.extract_slice %[[arg]][0] [%{{.*}}] [1]
+ // CHECK: %[[for:.*]]:3 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted2:.*]] = %[[extract2]], %[[hoisted1:.*]] = %[[extract1]])
+ %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+ %1 = tensor.extract_slice %t[0][%sz][1] : tensor<?xf32> to tensor<?xf32>
+ %2 = tensor.extract_slice %t[%sz][5][1] : tensor<?xf32> to tensor<5xf32>
+ // CHECK-DAG: %[[foo1:.*]] = "test.foo"(%[[hoisted1]])
+ // CHECK-DAG: %[[foo2:.*]] = "test.foo"(%[[hoisted2]])
+ %foo1 = "test.foo"(%1) : (tensor<?xf32>) -> (tensor<?xf32>)
+ %foo2 = "test.foo"(%2) : (tensor<5xf32>) -> (tensor<5xf32>)
+ %5 = tensor.insert_slice %foo2 into %t[%sz][5][1] : tensor<5xf32> into tensor<?xf32>
+ %6 = tensor.insert_slice %foo1 into %5[0][%sz][1] : tensor<?xf32> into tensor<?xf32>
+ // CHECK: scf.yield %[[t]], %[[foo2]], %[[foo1]]
+ scf.yield %6 : tensor<?xf32>
+ }
+ // CHECK: %[[insert2:.*]] = tensor.insert_slice %[[for]]#2 into %[[for]]#0[0] [%{{.*}}] [1]
+ // CHECK: %[[insert1:.*]] = tensor.insert_slice %[[for]]#1 into %[[insert2]][%{{.*}}] [5] [1]
+
+ // CHECK: return %[[insert1]]
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_hoist_overlapping_subsets(
+func.func @do_not_hoist_overlapping_subsets(%arg: tensor<?xf32>) -> tensor<?xf32> {
+ %lb = "test.foo"() : () -> (index)
+ %ub = "test.foo"() : () -> (index)
+ %step = "test.foo"() : () -> (index)
+ %sz1 = "test.foo"() : () -> (index)
+ %sz2 = "test.foo"() : () -> (index)
+
+ // CHECK: scf.for
+ %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+ // These two slices are potentially overlapping. Do not hoist.
+ // CHECK: tensor.extract_slice
+ // CHECK: tensor.extract_slice
+ %1 = tensor.extract_slice %t[0][%sz1][1] : tensor<?xf32> to tensor<?xf32>
+ %2 = tensor.extract_slice %t[10][%sz2][1] : tensor<?xf32> to tensor<?xf32>
+ // CHECK: "test.foo"
+ // CHECK: "test.foo"
+ %foo1 = "test.foo"(%1) : (tensor<?xf32>) -> (tensor<?xf32>)
+ %foo2 = "test.foo"(%2) : (tensor<?xf32>) -> (tensor<?xf32>)
+ // CHECK: tensor.insert_slice
+ // CHECK: tensor.insert_slice
+ %5 = tensor.insert_slice %foo2 into %t[0][%sz1][1] : tensor<?xf32> into tensor<?xf32>
+ %6 = tensor.insert_slice %foo1 into %5[10][%sz2][1] : tensor<?xf32> into tensor<?xf32>
+ // CHECK: scf.yield
+ scf.yield %6 : tensor<?xf32>
+ }
+
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @multiple_yields(
+// CHECK-SAME: %[[arg:.*]]: tensor<?xf32>
+func.func @multiple_yields(%arg: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+ %lb = "test.foo"() : () -> (index)
+ %ub = "test.foo"() : () -> (index)
+ %step = "test.foo"() : () -> (index)
+
+ // CHECK: %[[extract1:.*]] = tensor.extract_slice
+ // CHECK: %[[extract2:.*]] = tensor.extract_slice
+ // CHECK: scf.for {{.*}} iter_args(%{{.*}} = %[[arg]], %{{.*}} = %[[arg]], %{{.*}} = %[[extract1]], %{{.*}} = %[[extract2]])
+ %0:2 = scf.for %iv = %lb to %ub step %step iter_args(%t1 = %arg, %t2 = %arg)
+ -> (tensor<?xf32>, tensor<?xf32>) {
+ %1 = tensor.extract_slice %t1[0][5][1] : tensor<?xf32> to tensor<5xf32>
+ %2 = tensor.extract_slice %t2[5][5][1] : tensor<?xf32> to tensor<5xf32>
+ // CHECK: "test.foo"
+ // CHECK: "test.foo"
+ %foo1 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+ %foo2 = "test.foo"(%2) : (tensor<5xf32>) -> (tensor<5xf32>)
+ %5 = tensor.insert_slice %foo2 into %t1[0][5][1] : tensor<5xf32> into tensor<?xf32>
+ %6 = tensor.insert_slice %foo1 into %t2[5][5][1] : tensor<5xf32> into tensor<?xf32>
+ // CHECK: scf.yield
+ scf.yield %5, %6 : tensor<?xf32>, tensor<?xf32>
+ }
+ // CHECK: tensor.insert_slice
+ // CHECK: tensor.insert_slice
+
+ return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_hoist_swapping_yields(
+func.func @do_not_hoist_swapping_yields(%arg: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+ %lb = "test.foo"() : () -> (index)
+ %ub = "test.foo"() : () -> (index)
+ %step = "test.foo"() : () -> (index)
+
+ // CHECK: scf.for
+ %0:2 = scf.for %iv = %lb to %ub step %step iter_args(%t1 = %arg, %t2 = %arg)
+ -> (tensor<?xf32>, tensor<?xf32>) {
+ // CHECK: tensor.extract_slice
+ // CHECK: tensor.extract_slice
+ %1 = tensor.extract_slice %t1[0][5][1] : tensor<?xf32> to tensor<5xf32>
+ %2 = tensor.extract_slice %t2[5][5][1] : tensor<?xf32> to tensor<5xf32>
+ // CHECK: "test.foo"
+ // CHECK: "test.foo"
+ %foo1 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+ %foo2 = "test.foo"(%2) : (tensor<5xf32>) -> (tensor<5xf32>)
+ // CHECK: tensor.insert_slice
+ // CHECK: tensor.insert_slice
+ %5 = tensor.insert_slice %foo2 into %t1[0][5][1] : tensor<5xf32> into tensor<?xf32>
+ %6 = tensor.insert_slice %foo1 into %t2[5][5][1] : tensor<5xf32> into tensor<?xf32>
+ // Swapping yields: do not hoist.
+ // CHECK: scf.yield
+ scf.yield %6, %5 : tensor<?xf32>, tensor<?xf32>
+ }
+
+ return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @non_subset_op(
+func.func @non_subset_op(%arg: tensor<?xf32>) -> tensor<?xf32> {
+ %lb = "test.foo"() : () -> (index)
+ %ub = "test.foo"() : () -> (index)
+ %step = "test.foo"() : () -> (index)
+
+ // CHECK: scf.for
+ %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+ // If any value along the use-def chain from the region iter_arg to the
+ // terminator is used by a non-subset op, no subset op along that chain can
+ // be hoisted. That is because it is unknown which parts of the value are
+ // accessed by the non-subset op.
+ // CHECK: "test.non_subset_op"
+ "test.non_subset_op"(%t) : (tensor<?xf32>) -> ()
+ // CHECK: tensor.extract_slice
+ %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+ // CHECK: "test.foo"
+ %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+ // CHECK: tensor.insert_slice
+ %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+ // CHECK: scf.yield
+ scf.yield %3 : tensor<?xf32>
+ }
+
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @non_loop_invariant_subset_op(
+func.func @non_loop_invariant_subset_op(%arg: tensor<?xf32>) -> tensor<?xf32> {
+ %lb = "test.foo"() : () -> (index)
+ %ub = "test.foo"() : () -> (index)
+ %step = "test.foo"() : () -> (index)
+
+ // CHECK: scf.for
+ %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+ // Subset ops that are not loop-invariant cannot be hoisted.
+ // CHECK: tensor.extract_slice
+ %1 = tensor.extract_slice %t[%iv][5][1] : tensor<?xf32> to tensor<5xf32>
+ // CHECK: "test.foo"
+ %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+ // CHECK: tensor.insert_slice
+ %3 = tensor.insert_slice %2 into %t[%iv][5][1] : tensor<5xf32> into tensor<?xf32>
+ // CHECK: scf.yield
+ scf.yield %3 : tensor<?xf32>
+ }
+
+ return %0 : tensor<?xf32>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 0448ecbef655fa5..0a2ae427169a99a 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7035,6 +7035,7 @@ cc_library(
":MemorySlotInterfaces",
":Rewrite",
":SideEffectInterfaces",
+ ":SubsetOpInterface",
":Support",
":TransformsPassIncGen",
":config",
More information about the Mlir-commits
mailing list