[Mlir-commits] [mlir] [mlir] Don't hoist transfers from potentially zero trip loops (PR #112752)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 17 10:37:08 PDT 2024
https://github.com/Max191 created https://github.com/llvm/llvm-project/pull/112752
The hoistRedundantVectorTransfers function does not verify loop bounds when hoisting vector transfers. This is not safe in general, since the loop may have a zero trip count. This PR uses ValueBounds to verify that the lower bound is less than the upper bound of the loop before hoisting.
Zero trip count loops can arise in GPU code generation, where a loop bound can depend on a thread ID. If not all threads execute the loop body, hoisting the transfers out of the loop makes the threads that skip the loop execute transfers they are not supposed to.
>From ca0bc32f7d13628ae688cf1290cdc6db99dee73a Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 17 Oct 2024 08:56:35 -0500
Subject: [PATCH] [mlir] Don't hoist transfers from potentially zero trip loops
Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
.../Dialect/Linalg/Transforms/Hoisting.cpp | 46 ++++++++
mlir/test/Dialect/Linalg/hoisting.mlir | 108 ++++++++++++++----
2 files changed, 134 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 94f6b602987555..f382b2c13b94d4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -208,6 +208,46 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root) {
root->walk(
[&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
+    // Find all loops that are certain to have a non-zero trip count. Loops
+    // outside this set are never hoisted from: hoisting out of a loop that may
+    // execute zero times would make a vector transfer run when it should not.
+    llvm::DenseSet<LoopLikeOpInterface> definiteNonZeroTripCountLoops;
+    root->walk([&](LoopLikeOpInterface loopLike) {
+      std::optional<SmallVector<OpFoldResult>> lbs =
+          loopLike.getLoopLowerBounds();
+      std::optional<SmallVector<OpFoldResult>> ubs =
+          loopLike.getLoopUpperBounds();
+      // If loop bounds cannot be found, assume a possibly zero trip count.
+      if (!lbs || !ubs)
+        return;
+      // Otherwise, use ValueBounds to find the maximum lower bound and the
+      // minimum upper bound of each dimension. The loop is known to execute
+      // only if maxLb < minUb holds for EVERY dimension, so a dimension that
+      // cannot be bounded, or that may be empty, rejects the whole loop.
+      for (auto [lb, ub] : llvm::zip_equal(*lbs, *ubs)) {
+        FailureOr<int64_t> maxLb =
+            ValueBoundsConstraintSet::computeConstantBound(
+                presburger::BoundType::UB, /*var=*/lb,
+                /*stopCondition=*/nullptr, /*closedUB=*/true);
+        if (failed(maxLb))
+          return;
+        FailureOr<int64_t> minUb =
+            ValueBoundsConstraintSet::computeConstantBound(
+                presburger::BoundType::LB, /*var=*/ub,
+                /*stopCondition=*/nullptr);
+        if (failed(minUb))
+          return;
+        // minUb <= maxLb: this dimension may have an empty range, so the loop
+        // as a whole may execute zero times.
+        if (*minUb <= *maxLb)
+          return;
+      }
+      // Insert only once every dimension has been checked. Inserting inside
+      // the loop above would accept a multi-dimensional loop as soon as its
+      // first dimension passed, even if a later dimension may be empty.
+      definiteNonZeroTripCountLoops.insert(loopLike);
+    });
+
root->walk([&](vector::TransferReadOp transferRead) {
if (!isa<MemRefType>(transferRead.getShapedType()))
return WalkResult::advance();
@@ -220,6 +260,12 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root) {
if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop))
return WalkResult::advance();
+      // Only loops proven to run at least once are safe to hoist out of.
+      if (definiteNonZeroTripCountLoops.count(loop) == 0) {
+        LLVM_DEBUG(DBGS() << "Loop may have zero trip count: " << *loop
+                          << "\n");
+        return WalkResult::advance();
+      }
LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
<< "\n");
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 241b8a486c012e..f1bd14d233bd59 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -8,21 +8,21 @@
// CHECK-SAME: %[[MEMREF4:[a-zA-Z0-9]*]]: memref<?x?xf32>,
// CHECK-SAME: %[[MEMREF5:[a-zA-Z0-9]*]]: memref<?x?xf32>,
// CHECK-SAME: %[[VAL:[a-zA-Z0-9]*]]: index,
-// CHECK-SAME: %[[LB:[a-zA-Z0-9]*]]: index,
-// CHECK-SAME: %[[UB:[a-zA-Z0-9]*]]: index,
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]*]]: index,
// CHECK-SAME: %[[CMP:[a-zA-Z0-9]*]]: i1
func.func @hoist_vector_transfer_pairs(
%memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>, %memref2: memref<?x?xf32>,
%memref3: memref<?x?xf32>, %memref4: memref<?x?xf32>, %memref5: memref<?x?xf32>,
- %val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
+ %val: index, %step: index, %cmp: i1) {
+ %lb = arith.constant 0 : index
+ %ub = arith.constant 16 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
-// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) {
+// CHECK: scf.for %[[I:.*]] = {{.*}} to {{.*}} step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) {
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<2xf32>
-// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) {
+// CHECK: scf.for %[[J:.*]] = {{.*}} to {{.*}} step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) {
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<3xf32>
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<4xf32>
// CHECK: "some_crippling_use"(%[[MEMREF4]]) : (memref<?x?xf32>) -> ()
@@ -92,15 +92,15 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: %[[MEMREF2:[a-zA-Z0-9]*]]: memref<?x?xf32>,
// CHECK-SAME: %[[MEMREF3:[a-zA-Z0-9]*]]: memref<?x?xf32>,
// CHECK-SAME: %[[VAL:[a-zA-Z0-9]*]]: index,
-// CHECK-SAME: %[[LB:[a-zA-Z0-9]*]]: index,
-// CHECK-SAME: %[[UB:[a-zA-Z0-9]*]]: index,
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]*]]: index,
// CHECK-SAME: %[[RANDOM:[a-zA-Z0-9]*]]: index,
// CHECK-SAME: %[[CMP:[a-zA-Z0-9]*]]: i1
func.func @hoist_vector_transfer_pairs_disjoint(
%memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>,
- %memref2: memref<?x?xf32>, %memref3: memref<?x?xf32>, %val: index, %lb : index, %ub : index,
+ %memref2: memref<?x?xf32>, %memref3: memref<?x?xf32>, %val: index,
%step: index, %random_index : index, %cmp: i1) {
+ %lb = arith.constant 0 : index
+ %ub = arith.constant 16 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c3 = arith.constant 3 : index
@@ -110,9 +110,9 @@ func.func @hoist_vector_transfer_pairs_disjoint(
// CHECK: vector.transfer_read %[[MEMREF2]]{{.*}} : memref<?x?xf32>, vector<3xf32>
// CHECK: vector.transfer_read %[[MEMREF3]]{{.*}} : memref<?x?xf32>, vector<4xf32>
// CHECK: vector.transfer_read %[[MEMREF3]]{{.*}} : memref<?x?xf32>, vector<4xf32>
-// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) ->
+// CHECK: scf.for %[[I:.*]] = {{.*}} to {{.*}} step %[[STEP]] iter_args({{.*}}) ->
// CHECK-SAME: (vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
-// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) ->
+// CHECK: scf.for %[[J:.*]] = {{.*}} to {{.*}} step %[[STEP]] iter_args({{.*}}) ->
// CHECK-SAME: (vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
// CHECK: vector.transfer_read %[[MEMREF1]]{{.*}} : memref<?x?xf32>, vector<2xf32>
// CHECK: vector.transfer_read %[[MEMREF1]]{{.*}} : memref<?x?xf32>, vector<2xf32>
@@ -308,6 +308,62 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: func.func @no_hoisting_zero_trip_loop
+func.func @no_hoisting_zero_trip_loop(%arg0: memref<20xi32>, %arg1: memref<20xi32>, %lb: index, %ub: index) {
+ %c0_i32 = arith.constant 0 : i32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // Nothing is known about %lb and %ub, so this loop may run zero times: do not hoist.
+
+ // CHECK: scf.for {{.*}} {
+ // CHECK-NEXT: vector.transfer_read
+ // CHECK-NEXT: vector.transfer_write
+ scf.for %arg2 = %lb to %ub step %c1 {
+ %read = vector.transfer_read %arg0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
+ vector.transfer_write %read, %arg1[%c0] {in_bounds = [true]} : vector<4xi32>, memref<20xi32>
+ }
+
+ // %lb_0 = min(%lb, 8) is at most 8 and %ub_0 = max(%ub, 4) is at least 4.
+ // That does not prove %lb_0 < %ub_0 (e.g. %lb = 8, %ub = 4 gives 8 to 4), so do not hoist.
+ %lb_0 = affine.min affine_map<(d0) -> (d0, 8)>(%lb)
+ %ub_0 = affine.max affine_map<(d0) -> (d0, 4)>(%ub)
+
+ // CHECK: scf.for {{.*}} {
+ // CHECK-NEXT: vector.transfer_read
+ // CHECK-NEXT: vector.transfer_write
+ scf.for %arg2 = %lb_0 to %ub_0 step %c1 {
+ %read = vector.transfer_read %arg0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
+ vector.transfer_write %read, %arg1[%c0] {in_bounds = [true]} : vector<4xi32>, memref<20xi32>
+ }
+
+ // %lb_1 = min(%lb, 4) is at most 4 and %ub_1 = max(%ub, 8) is at least 8, so
+ // %lb_1 < %ub_1 always holds and the loop runs at least once: hoisting is allowed.
+ %lb_1 = affine.min affine_map<(d0) -> (d0, 4)>(%lb)
+ %ub_1 = affine.max affine_map<(d0) -> (d0, 8)>(%ub)
+
+ // CHECK: vector.transfer_read
+ // CHECK: scf.for {{.*}} {
+ // CHECK-NEXT: "prevent.dce"
+ scf.for %arg2 = %lb_1 to %ub_1 step %c1 {
+ %read = vector.transfer_read %arg0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
+ "prevent.dce"(%read) : (vector<4xi32>) ->()
+ vector.transfer_write %read, %arg1[%c0] {in_bounds = [true]} : vector<4xi32>, memref<20xi32>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// Regression test - `vector.transfer_read` below should not be hoisted.
// Indeed, %collapse_shape (written to by `vector.transfer_write`) and %alloca
// (read by `vector.transfer_read`) alias.
@@ -436,7 +492,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: #[[$MAP4:.+]] = affine_map<()[s0] -> (s0 + 4)>
// CHECK-LABEL: func.func @hoist_vector_transfer_pairs_disjoint_dynamic
-// CHECK-SAME: (%[[BUFFER:.+]]: memref<?x?xf32>, %{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[I0:.+]]: index)
+// CHECK-SAME: (%[[BUFFER:.+]]: memref<?x?xf32>, %{{.+}}: index, %[[I0:.+]]: index)
// CHECK: %[[PLUS1:.+]] = affine.apply #[[$MAP1]]()[%[[I0]]]
// CHECK: %[[PLUS4:.+]] = affine.apply #[[$MAP4]]()[%[[I0]]]
@@ -451,7 +507,9 @@ module attributes {transform.with_named_sequence} {
// CHECK: vector.transfer_write %{{.+}}, %[[BUFFER]][%[[I0]], %[[I0]]]
func.func @hoist_vector_transfer_pairs_disjoint_dynamic(
- %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index) {
+ %buffer: memref<?x?xf32>, %step: index, %i0 : index) {
+ %lb = arith.constant 0 : index
+ %ub = arith.constant 16 : index
%cst = arith.constant 0.0 : f32
%i1 = affine.apply affine_map<(d0) -> (d0 + 1)>(%i0)
%i2 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i0)
@@ -494,7 +552,9 @@ module attributes {transform.with_named_sequence} {
// CHECK-COUNT-2: vector.transfer_write
func.func @hoist_vector_transfer_pairs_overlapping_dynamic(
- %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index) {
+ %buffer: memref<?x?xf32>, %step: index, %i0 : index) {
+ %lb = arith.constant 0 : index
+ %ub = arith.constant 16 : index
%cst = arith.constant 0.0 : f32
%i1 = affine.apply affine_map<(d0) -> (d0 + 3)>(%i0)
@@ -534,7 +594,9 @@ module attributes {transform.with_named_sequence} {
// CHECK: return
func.func @hoist_vector_transfer_pairs_disjoint_dynamic(
- %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index, %i1 : index) {
+ %buffer: memref<?x?xf32>, %step: index, %i0 : index, %i1 : index) {
+ %lb = arith.constant 0 : index
+ %ub = arith.constant 16 : index
%cst = arith.constant 0.0 : f32
%i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1)
%i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1)
@@ -571,7 +633,7 @@ module attributes {transform.with_named_sequence} {
// Test hoisting of vector.extract/vector.broadcast pairs
// CHECK-LABEL: func.func @hoist_vector_broadcasts
-// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>) -> vector<3x4xf32> {
+// CHECK-SAME: (%{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>) -> vector<3x4xf32> {
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][0] : vector<4xf32> from vector<3x4xf32>
// CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} {
// CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
@@ -580,7 +642,9 @@ module attributes {transform.with_named_sequence} {
// CHECK-NEXT: %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
// CHECK-NEXT: return %[[BCAST]] : vector<3x4xf32>
-func.func @hoist_vector_broadcasts(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>) -> vector<3x4xf32> {
+func.func @hoist_vector_broadcasts(%step : index, %vec : vector<3x4xf32>) -> vector<3x4xf32> {
+ %lb = arith.constant 0 : index
+ %ub = arith.constant 16 : index
%bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
%extract = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
%use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
@@ -605,7 +669,7 @@ module attributes {transform.with_named_sequence} {
// Test hoisting of vector.extract/vector.broadcast pairs with dynamic position
// CHECK-LABEL: func.func @hoist_vector_broadcasts
-// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> {
+// CHECK-SAME: (%{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> {
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][%[[POS]]] : vector<4xf32> from vector<3x4xf32>
// CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} {
// CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
@@ -614,7 +678,9 @@ module attributes {transform.with_named_sequence} {
// CHECK-NEXT: %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
// CHECK-NEXT: return %[[BCAST]] : vector<3x4xf32>
-func.func @hoist_vector_broadcasts_dynamic(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>, %pos: index) -> vector<3x4xf32> {
+func.func @hoist_vector_broadcasts_dynamic(%step : index, %vec : vector<3x4xf32>, %pos: index) -> vector<3x4xf32> {
+ %lb = arith.constant 0 : index
+ %ub = arith.constant 16 : index
%bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
%extract = vector.extract %iarg[%pos] : vector<4xf32> from vector<3x4xf32>
%use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
@@ -639,7 +705,7 @@ module attributes {transform.with_named_sequence} {
// Test hoisting of vector.extract/vector.broadcast pairs with multiple iter_args
// CHECK-LABEL: func.func @hoist_vector_broadcasts_multiple
-// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC1:.+]]: vector<3x4xf32>,
+// CHECK-SAME: (%{{.+}}: index, %[[VEC1:.+]]: vector<3x4xf32>,
// CHECK-SAME: %[[VEC2:.+]]: vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
// CHECK-DAG: %[[EXTRACT1:.+]] = vector.extract %[[VEC1]][0] : vector<4xf32> from vector<3x4xf32>
// CHECK-DAG: %[[EXTRACT2:.+]] = vector.extract %[[VEC2]][1] : vector<5xf32> from vector<3x5xf32>
@@ -652,7 +718,9 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[BCAST2:.+]] = vector.broadcast %[[LOOP]]#1 : vector<5xf32> to vector<3x5xf32>
// CHECK-NEXT: return %[[BCAST1]], %[[BCAST2]] : vector<3x4xf32>, vector<3x5xf32>
-func.func @hoist_vector_broadcasts_multiple(%lb : index, %ub : index, %step : index, %vec1 : vector<3x4xf32>, %vec2 : vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
+func.func @hoist_vector_broadcasts_multiple(%step : index, %vec1 : vector<3x4xf32>, %vec2 : vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
+ %lb = arith.constant 0 : index
+ %ub = arith.constant 16 : index
%bcast_vec:2 = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec1, %iarg2 = %vec2) -> (vector<3x4xf32>, vector<3x5xf32>) {
%extract1 = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
%extract2 = vector.extract %iarg2[1] : vector<5xf32> from vector<3x5xf32>
More information about the Mlir-commits
mailing list