[Mlir-commits] [mlir] [mlir][Transforms] LISH: Improve bypass analysis for loop-like ops (PR #70623)
Matthias Springer
llvmlistbot at llvm.org
Tue Oct 31 19:07:15 PDT 2023
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/70623
>From 5feaa44d6f7ec66de821ee3d11cf7e8f3724e5eb Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Wed, 1 Nov 2023 11:06:12 +0900
Subject: [PATCH] [mlir][WIP] Bypass analysis for loops
---
.../Utils/LoopInvariantCodeMotionUtils.cpp | 72 ++++++++++++++-----
.../loop-invariant-subset-hoisting.mlir | 35 +++++++++
2 files changed, 91 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 01318cf7328b543..53bdb7aafe41a0c 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -120,8 +120,10 @@ namespace {
class MatchingSubsets {
public:
/// Insert a subset op.
- void insert(SubsetOpInterface op) {
+ void insert(SubsetOpInterface op, bool collectHoistableOps = true) {
allSubsetOps.push_back(op);
+ if (!collectHoistableOps)
+ return;
if (auto extractionOp =
dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
insertExtractionOp(extractionOp);
@@ -148,6 +150,15 @@ class MatchingSubsets {
});
}
+ /// Populate subset ops starting from the given region iter_arg. Return
+ /// "failure" if non-subset ops are found along the path to the loop yielding
+ /// op or if there is no single path to the tied yielded operand. If
+ /// `collectHoistableOps` is set to "false", subset ops are gathered
+ /// throughout the traversal, but not enumerated by `getHoistableSubsetOps`.
+ LogicalResult populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
+ BlockArgument iterArg,
+ bool collectHoistableOps = true);
+
private:
/// Helper function for equivalence of tensor values. Since only insertion
/// subset ops (that are also destination style ops) are followed when
@@ -225,18 +236,12 @@ static OpOperand *getSingleTerminatorUse(Value value) {
return nullptr;
}
-/// Hoist all subset ops that operate on the idx-th region iter_arg of the given
-/// loop-like op and index into loop-invariant subset locations. Return the
-/// newly created loop op (that has extra iter_args) or the original loop op if
-/// nothing was hoisted.
-static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
- BlockArgument iterArg) {
- IRRewriter rewriter(loopLike.getContext());
+LogicalResult
+MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
+ BlockArgument iterArg,
+ bool collectHoistableOps) {
assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
- auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
- int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
Value value = iterArg;
- MatchingSubsets subsets;
// Traverse use-def chain. Subset ops can be hoisted only if all ops along the
// use-def chain starting from the region iter_arg are subset extraction or
@@ -249,21 +254,39 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
Value nextValue = {};
for (OpOperand &use : value.getUses()) {
+ if (auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
+ // Subset ops in nested loops are collected to check if there are only
+ // disjoint subset ops, but such subset ops are not subject to hoisting.
+ // To hoist subset ops from nested loops, the hoisting transformation
+ // should be run on the nested loop.
+ auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg(&use);
+ if (!nestedIterArg)
+ return failure();
+ // Note: `populateSubsetOpsAtIterArg` fails if there is no single SSA
+ // use-def chain starting at `nestedIterArg` and terminating in the
+ // tied, yielding operand.
+ if (failed(populateSubsetOpsAtIterArg(nestedLoop, nestedIterArg,
+ /*collectHoistableOps=*/false)))
+ return failure();
+ nextValue = nestedLoop.getTiedLoopResult(&use);
+ continue;
+ }
+
auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner());
if (!subsetOp)
- return loopLike;
- subsets.insert(subsetOp);
+ return failure();
+ insert(subsetOp);
if (auto insertionOp =
dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
// The value must be used as a destination. (In case of a source, the
// entire tensor would be read, which would prevent any hoisting.)
if (&use != &insertionOp.getDestinationOperand())
- return loopLike;
+ return failure();
// There must be a single use-def chain from the region iter_arg to the
// terminator. I.e., only one insertion op. Branches are not supported.
if (nextValue)
- return loopLike;
+ return failure();
nextValue = insertionOp.getUpdatedDestination();
}
}
@@ -271,7 +294,7 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
// Nothing can be hoisted if the chain does not continue with loop yielding
// op or a subset insertion op.
if (!nextValue)
- return loopLike;
+ return failure();
value = nextValue;
}
@@ -279,6 +302,23 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
// loop and the yielded value is the `idx`-th operand. (I.e., there is no
// swapping yield.)
if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand)
+ return failure();
+
+ return success();
+}
+
+/// Hoist all subset ops that operate on `iterArg` (a region iter_arg of the
+/// given loop-like op) and that index into loop-invariant subset locations.
+/// Return the newly created loop op (that has extra iter_args) or the original
+/// loop op if nothing was hoisted.
+static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
+ BlockArgument iterArg) {
+ assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
+ auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
+ int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
+ IRRewriter rewriter(loopLike.getContext());
+ MatchingSubsets subsets;
+ if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
return loopLike;
// Hoist all matching extraction-insertion pairs one-by-one.
diff --git a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
index 5cded4c99182c14..b9161f4e20d1927 100644
--- a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
+++ b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
@@ -235,3 +235,38 @@ func.func @non_loop_invariant_subset_op(%arg: tensor<?xf32>) -> tensor<?xf32> {
return %0 : tensor<?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @nested_hoisting(
+// CHECK-SAME: %[[arg:.*]]: tensor<?xf32>
+func.func @nested_hoisting(%arg: tensor<?xf32>) -> tensor<?xf32> {
+ %lb = "test.foo"() : () -> (index)
+ %ub = "test.foo"() : () -> (index)
+ %step = "test.foo"() : () -> (index)
+
+ // CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]][0] [5] [1]
+ // CHECK: %[[extract2:.*]] = tensor.extract_slice %[[arg]][5] [5] [1]
+ // CHECK: %[[for:.*]]:3 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]], %[[hoisted2:.*]] = %[[extract2]])
+ %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+ %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+ // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
+ %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+ %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+ // CHECK: %[[for2:.*]]:2 = {{.*}} iter_args(%[[t2:.*]] = %[[t]], %[[hoisted2_nested:.*]] = %[[hoisted2]])
+ %4 = scf.for %iv2 = %lb to %ub step %step iter_args(%t2 = %3) -> (tensor<?xf32>) {
+ %5 = tensor.extract_slice %t2[5][5][1] : tensor<?xf32> to tensor<5xf32>
+ // CHECK: %[[foo2:.*]] = "test.foo"(%[[hoisted2_nested]])
+ %6 = "test.foo"(%5) : (tensor<5xf32>) -> (tensor<5xf32>)
+ %7 = tensor.insert_slice %6 into %t2[5][5][1] : tensor<5xf32> into tensor<?xf32>
+ // CHECK: scf.yield %[[t2]], %[[foo2]]
+ scf.yield %7 : tensor<?xf32>
+ }
+ // CHECK: scf.yield %[[for2]]#0, %[[foo]], %[[for2]]#1
+ scf.yield %4 : tensor<?xf32>
+ }
+ // CHECK: %[[insert:.*]] = tensor.insert_slice %[[for]]#2 into %[[for]]#0[5] [5] [1]
+ // CHECK: %[[insert2:.*]] = tensor.insert_slice %[[for]]#1 into %[[insert]][0] [5] [1]
+ // CHECK: return %[[insert2]]
+ return %0 : tensor<?xf32>
+}
More information about the Mlir-commits
mailing list