[Mlir-commits] [mlir] [mlir] Don't hoist transfers from potentially zero trip loops (PR #112752)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 18 06:55:44 PDT 2024


https://github.com/Max191 updated https://github.com/llvm/llvm-project/pull/112752

>From 5c0aae2493199568e2e13e89760f8ba804228866 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 17 Oct 2024 08:56:35 -0500
Subject: [PATCH 1/3] [mlir] Don't hoist transfers from potentially zero trip
 loops

Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
 .../Dialect/Linalg/Transforms/Hoisting.cpp    |  45 ++++++++
 mlir/test/Dialect/Linalg/hoisting.mlir        | 108 ++++++++++++++----
 2 files changed, 133 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 94f6b602987555..92345e7695151c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -208,6 +208,45 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root) {
     root->walk(
         [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
 
+    // Find all loops that are certain to have non zero trip count. Any loops
+    // that are not part of this set cannot be hoisted from, since hoisting from
+    // a potentially zero trip count loop may cause a vector transfer to be
+    // executed when it shouldn't be.
+    llvm::DenseSet<LoopLikeOpInterface> definiteNonZeroTripCountLoops;
+    root->walk([&](LoopLikeOpInterface loopLike) {
+      std::optional<SmallVector<OpFoldResult>> lbs =
+          loopLike.getLoopLowerBounds();
+      std::optional<SmallVector<OpFoldResult>> ubs =
+          loopLike.getLoopUpperBounds();
+      // If loop bounds cannot be found, assume possibly zero trip count.
+      if (!lbs || !ubs) {
+        return;
+      }
+      // Otherwise, use ValueBounds to find the maximum lower bound and
+      // minimum upper bound. If the bounds are found, and maxLb is less
+      // than the minUb, then the loop will not have zero trip count.
+      for (auto [lb, ub] : llvm::zip_equal(lbs.value(), ubs.value())) {
+        FailureOr<int64_t> maxLb =
+            ValueBoundsConstraintSet::computeConstantBound(
+                presburger::BoundType::UB, /*var=*/lb,
+                /*stopCondition=*/nullptr, /*closedUB=*/true);
+        if (failed(maxLb)) {
+          return;
+        }
+        FailureOr<int64_t> minUb =
+            ValueBoundsConstraintSet::computeConstantBound(
+                presburger::BoundType::LB, /*var=*/ub,
+                /*stopCondition=*/nullptr);
+        if (failed(minUb)) {
+          return;
+        }
+        if (minUb.value() <= maxLb.value()) {
+          return;
+        }
+        definiteNonZeroTripCountLoops.insert(loopLike);
+      }
+    });
+
     root->walk([&](vector::TransferReadOp transferRead) {
       if (!isa<MemRefType>(transferRead.getShapedType()))
         return WalkResult::advance();
@@ -220,6 +259,12 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root) {
       if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop))
         return WalkResult::advance();
 
+      if (!definiteNonZeroTripCountLoops.contains(loop)) {
+        LLVM_DEBUG(DBGS() << "Loop may have zero trip count: " << *loop
+                          << "\n");
+        return WalkResult::advance();
+      }
+
       LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
                         << "\n");
 
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 241b8a486c012e..f1bd14d233bd59 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -8,21 +8,21 @@
 //  CHECK-SAME:   %[[MEMREF4:[a-zA-Z0-9]*]]: memref<?x?xf32>,
 //  CHECK-SAME:   %[[MEMREF5:[a-zA-Z0-9]*]]: memref<?x?xf32>,
 //  CHECK-SAME:   %[[VAL:[a-zA-Z0-9]*]]: index,
-//  CHECK-SAME:   %[[LB:[a-zA-Z0-9]*]]: index,
-//  CHECK-SAME:   %[[UB:[a-zA-Z0-9]*]]: index,
 //  CHECK-SAME:   %[[STEP:[a-zA-Z0-9]*]]: index,
 //  CHECK-SAME:   %[[CMP:[a-zA-Z0-9]*]]: i1
 func.func @hoist_vector_transfer_pairs(
     %memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>, %memref2: memref<?x?xf32>,
     %memref3: memref<?x?xf32>, %memref4: memref<?x?xf32>, %memref5: memref<?x?xf32>,
-    %val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
+    %val: index, %step: index, %cmp: i1) {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 16 : index
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.0 : f32
 
 // CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
-// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) {
+// CHECK: scf.for %[[I:.*]] = {{.*}} to {{.*}} step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) {
 // CHECK:   vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<2xf32>
-// CHECK:   scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) {
+// CHECK:   scf.for %[[J:.*]] = {{.*}} to {{.*}} step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) {
 // CHECK:     vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<3xf32>
 // CHECK:     vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<4xf32>
 // CHECK:     "some_crippling_use"(%[[MEMREF4]]) : (memref<?x?xf32>) -> ()
@@ -92,15 +92,15 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-SAME:   %[[MEMREF2:[a-zA-Z0-9]*]]: memref<?x?xf32>,
 //  CHECK-SAME:   %[[MEMREF3:[a-zA-Z0-9]*]]: memref<?x?xf32>,
 //  CHECK-SAME:   %[[VAL:[a-zA-Z0-9]*]]: index,
-//  CHECK-SAME:   %[[LB:[a-zA-Z0-9]*]]: index,
-//  CHECK-SAME:   %[[UB:[a-zA-Z0-9]*]]: index,
 //  CHECK-SAME:   %[[STEP:[a-zA-Z0-9]*]]: index,
 //  CHECK-SAME:   %[[RANDOM:[a-zA-Z0-9]*]]: index,
 //  CHECK-SAME:   %[[CMP:[a-zA-Z0-9]*]]: i1
 func.func @hoist_vector_transfer_pairs_disjoint(
     %memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>,
-    %memref2: memref<?x?xf32>, %memref3: memref<?x?xf32>, %val: index, %lb : index, %ub : index,
+    %memref2: memref<?x?xf32>, %memref3: memref<?x?xf32>, %val: index,
     %step: index, %random_index : index, %cmp: i1) {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 16 : index
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c3 = arith.constant 3 : index
@@ -110,9 +110,9 @@ func.func @hoist_vector_transfer_pairs_disjoint(
 // CHECK: vector.transfer_read %[[MEMREF2]]{{.*}} : memref<?x?xf32>, vector<3xf32>
 // CHECK: vector.transfer_read %[[MEMREF3]]{{.*}} : memref<?x?xf32>, vector<4xf32>
 // CHECK: vector.transfer_read %[[MEMREF3]]{{.*}} : memref<?x?xf32>, vector<4xf32>
-// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) ->
+// CHECK: scf.for %[[I:.*]] = {{.*}} to {{.*}} step %[[STEP]] iter_args({{.*}}) ->
 //  CHECK-SAME: (vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
-// CHECK:   scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) ->
+// CHECK:   scf.for %[[J:.*]] = {{.*}} to {{.*}} step %[[STEP]] iter_args({{.*}}) ->
 //  CHECK-SAME: (vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
 // CHECK:     vector.transfer_read %[[MEMREF1]]{{.*}} : memref<?x?xf32>, vector<2xf32>
 // CHECK:     vector.transfer_read %[[MEMREF1]]{{.*}} : memref<?x?xf32>, vector<2xf32>
@@ -308,6 +308,62 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL:  func.func @no_hoisting_zero_trip_loop
+func.func @no_hoisting_zero_trip_loop(%arg0: memref<20xi32>, %arg1: memref<20xi32>, %lb: index, %ub: index) {
+  %c0_i32 = arith.constant 0 : i32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // %lb and %ub are unbounded, so do not hoist.
+
+  // CHECK:       scf.for {{.*}} {
+  // CHECK-NEXT:    vector.transfer_read
+  // CHECK-NEXT:    vector.transfer_write
+  scf.for %arg2 = %lb to %ub step %c1 {
+    %read = vector.transfer_read %arg0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
+    vector.transfer_write %read, %arg1[%c0] {in_bounds = [true]} : vector<4xi32>, memref<20xi32>
+  }
+
+  // %lb_0 is in range [%lb, 8], and %ub_0 is in range [4, %ub].
+  // Since %lb_0 could be greater than %ub_0, do not hoist.
+  %lb_0 = affine.min affine_map<(d0) -> (d0, 8)>(%lb)
+  %ub_0 = affine.max affine_map<(d0) -> (d0, 4)>(%ub)
+
+  // CHECK:       scf.for {{.*}} {
+  // CHECK-NEXT:    vector.transfer_read
+  // CHECK-NEXT:    vector.transfer_write
+  scf.for %arg2 = %lb_0 to %ub_0 step %c1 {
+    %read = vector.transfer_read %arg0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
+    vector.transfer_write %read, %arg1[%c0] {in_bounds = [true]} : vector<4xi32>, memref<20xi32>
+  }
+
+  // %lb_1 is in range [%lb, 4], and %ub_1 is in range [8, %ub].
+  // Since %lb_1 is guaranteed to be less than %ub_1, hoisting is possible.
+  %lb_1 = affine.min affine_map<(d0) -> (d0, 4)>(%lb)
+  %ub_1 = affine.max affine_map<(d0) -> (d0, 8)>(%ub)
+
+  // CHECK:    vector.transfer_read
+  // CHECK:       scf.for {{.*}} {
+  // CHECK-NEXT:    "prevent.dce"
+  scf.for %arg2 = %lb_1 to %ub_1 step %c1 {
+    %read = vector.transfer_read %arg0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
+    "prevent.dce"(%read) : (vector<4xi32>) ->()
+    vector.transfer_write %read, %arg1[%c0] {in_bounds = [true]} : vector<4xi32>, memref<20xi32>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // Regression test - `vector.transfer_read` below should not be hoisted.
 // Indeed, %collapse_shape (written to by `vector.transfer_write`) and %alloca
 // (read by `vector.transfer_read`) alias.
@@ -436,7 +492,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK: #[[$MAP4:.+]] = affine_map<()[s0] -> (s0 + 4)>
 
 //   CHECK-LABEL: func.func @hoist_vector_transfer_pairs_disjoint_dynamic
-//    CHECK-SAME: (%[[BUFFER:.+]]: memref<?x?xf32>, %{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[I0:.+]]: index)
+//    CHECK-SAME: (%[[BUFFER:.+]]: memref<?x?xf32>, %{{.+}}: index, %[[I0:.+]]: index)
 
 //         CHECK:   %[[PLUS1:.+]] = affine.apply #[[$MAP1]]()[%[[I0]]]
 //         CHECK:   %[[PLUS4:.+]] = affine.apply #[[$MAP4]]()[%[[I0]]]
@@ -451,7 +507,9 @@ module attributes {transform.with_named_sequence} {
 //         CHECK:   vector.transfer_write %{{.+}}, %[[BUFFER]][%[[I0]], %[[I0]]]
 
 func.func @hoist_vector_transfer_pairs_disjoint_dynamic(
-    %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index) {
+    %buffer: memref<?x?xf32>, %step: index, %i0 : index) {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 16 : index
   %cst = arith.constant 0.0 : f32
   %i1 = affine.apply affine_map<(d0) -> (d0 + 1)>(%i0)
   %i2 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i0)
@@ -494,7 +552,9 @@ module attributes {transform.with_named_sequence} {
 // CHECK-COUNT-2:     vector.transfer_write
 
 func.func @hoist_vector_transfer_pairs_overlapping_dynamic(
-    %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index) {
+    %buffer: memref<?x?xf32>, %step: index, %i0 : index) {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 16 : index
   %cst = arith.constant 0.0 : f32
   %i1 = affine.apply affine_map<(d0) -> (d0 + 3)>(%i0)
 
@@ -534,7 +594,9 @@ module attributes {transform.with_named_sequence} {
 //         CHECK:   return
 
 func.func @hoist_vector_transfer_pairs_disjoint_dynamic(
-    %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index, %i1 : index) {
+    %buffer: memref<?x?xf32>, %step: index, %i0 : index, %i1 : index) {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 16 : index
   %cst = arith.constant 0.0 : f32
   %i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1)
   %i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1)
@@ -571,7 +633,7 @@ module attributes {transform.with_named_sequence} {
 // Test hoisting of vector.extract/vector.broadcast pairs
 
 // CHECK-LABEL:  func.func @hoist_vector_broadcasts
-//       CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>) -> vector<3x4xf32> {
+//       CHECK-SAME: (%{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>) -> vector<3x4xf32> {
 //       CHECK:        %[[EXTRACT:.+]] = vector.extract %[[VEC]][0] : vector<4xf32> from vector<3x4xf32>
 //       CHECK-NEXT:   %[[LOOP:.+]] = scf.for {{.*}} {
 //       CHECK-NEXT:     %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
@@ -580,7 +642,9 @@ module attributes {transform.with_named_sequence} {
 //       CHECK-NEXT:   %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
 //       CHECK-NEXT:   return %[[BCAST]] : vector<3x4xf32>
 
-func.func @hoist_vector_broadcasts(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>) -> vector<3x4xf32> {
+func.func @hoist_vector_broadcasts(%step : index, %vec : vector<3x4xf32>) -> vector<3x4xf32> {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 16 : index
   %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
     %extract = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
     %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
@@ -605,7 +669,7 @@ module attributes {transform.with_named_sequence} {
 // Test hoisting of vector.extract/vector.broadcast pairs with dynamic position
 
 // CHECK-LABEL:  func.func @hoist_vector_broadcasts
-//       CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> {
+//       CHECK-SAME: (%{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> {
 //       CHECK:        %[[EXTRACT:.+]] = vector.extract %[[VEC]][%[[POS]]] : vector<4xf32> from vector<3x4xf32>
 //       CHECK-NEXT:   %[[LOOP:.+]] = scf.for {{.*}} {
 //       CHECK-NEXT:     %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
@@ -614,7 +678,9 @@ module attributes {transform.with_named_sequence} {
 //       CHECK-NEXT:   %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
 //       CHECK-NEXT:   return %[[BCAST]] : vector<3x4xf32>
 
-func.func @hoist_vector_broadcasts_dynamic(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>, %pos: index) -> vector<3x4xf32> {
+func.func @hoist_vector_broadcasts_dynamic(%step : index, %vec : vector<3x4xf32>, %pos: index) -> vector<3x4xf32> {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 16 : index
   %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
     %extract = vector.extract %iarg[%pos] : vector<4xf32> from vector<3x4xf32>
     %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
@@ -639,7 +705,7 @@ module attributes {transform.with_named_sequence} {
 // Test hoisting of vector.extract/vector.broadcast pairs with multiple iter_args
 
 // CHECK-LABEL:  func.func @hoist_vector_broadcasts_multiple
-//       CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC1:.+]]: vector<3x4xf32>,
+//       CHECK-SAME: (%{{.+}}: index, %[[VEC1:.+]]: vector<3x4xf32>,
 //       CHECK-SAME:  %[[VEC2:.+]]: vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
 //       CHECK-DAG:     %[[EXTRACT1:.+]] = vector.extract %[[VEC1]][0] : vector<4xf32> from vector<3x4xf32>
 //       CHECK-DAG:     %[[EXTRACT2:.+]] = vector.extract %[[VEC2]][1] : vector<5xf32> from vector<3x5xf32>
@@ -652,7 +718,9 @@ module attributes {transform.with_named_sequence} {
 //       CHECK-DAG:     %[[BCAST2:.+]] = vector.broadcast %[[LOOP]]#1 : vector<5xf32> to vector<3x5xf32>
 //       CHECK-NEXT:    return %[[BCAST1]], %[[BCAST2]] : vector<3x4xf32>, vector<3x5xf32>
 
-func.func @hoist_vector_broadcasts_multiple(%lb : index, %ub : index, %step : index, %vec1 : vector<3x4xf32>, %vec2 : vector<3x5xf32>) ->  (vector<3x4xf32>, vector<3x5xf32>) {
+func.func @hoist_vector_broadcasts_multiple(%step : index, %vec1 : vector<3x4xf32>, %vec2 : vector<3x5xf32>) ->  (vector<3x4xf32>, vector<3x5xf32>) {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 16 : index
   %bcast_vec:2 = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec1, %iarg2 = %vec2) -> (vector<3x4xf32>, vector<3x5xf32>) {
     %extract1 = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
     %extract2 = vector.extract %iarg2[1] : vector<5xf32> from vector<3x5xf32>

>From 9e1f7a1dccf213565329909af047d840799fe9f3 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 17 Oct 2024 16:38:12 -0400
Subject: [PATCH 2/3] add option for trip count verification

Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
 .../Linalg/TransformOps/LinalgTransformOps.td |  6 +-
 .../mlir/Dialect/Linalg/Transforms/Hoisting.h |  6 +-
 .../TransformOps/LinalgTransformOps.cpp       |  2 +-
 .../Dialect/Linalg/Transforms/Hoisting.cpp    | 67 ++++++++++---------
 mlir/test/Dialect/Linalg/hoisting.mlir        | 67 ++++++++-----------
 5 files changed, 72 insertions(+), 76 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 0915bbde3072b0..040c04b0410ecf 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2294,13 +2294,15 @@ def HoistRedundantVectorTransfersOp :
     function op.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$target);
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                   UnitAttr:$verify_non_zero_trip);
   let results = (outs TransformHandleTypeInterface:$transformed);
 
   let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
 
   let builders = [
-    OpBuilder<(ins "Value":$target)>,
+    OpBuilder<(ins "Value":$target,
+               CArg<"bool", "false">:$verify_non_zero_trip)>,
   ];
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
index 236c2ce7d48e39..9e17bf4fb0e449 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
@@ -29,6 +29,9 @@ namespace linalg {
 ///   4. The source operands for vector.transfer_{read|write} do not originate
 ///   from Ops implementing ViewLikeOpInterface (to reduce the risk of
 ///   aliasing).
+///   5. If `verifyNonZeroTrip` is true, then the lower bound of the loop must
+///   be statically smaller than the upper bound of the loop, guaranteeing that
+///   the loop body will execute at least once.
 /// To improve hoisting opportunities, call the `moveLoopInvariantCode` helper
 /// function on the candidate loop above which to hoist. Hoisting the transfers
 /// results in scf::ForOp yielding the value that originally transited through
@@ -41,7 +44,8 @@ namespace linalg {
 ///
 /// WARNING: This hoisting does not model parallelism and is generally incorrect
 /// when used on distributed loops with memref semantics!
-void hoistRedundantVectorTransfers(Operation *root);
+void hoistRedundantVectorTransfers(Operation *root,
+                                   bool verifyNonZeroTrip = false);
 
 /// Hoist vector.extract/vector.broadcast pairs out of immediately enclosing
 /// scf::ForOp iteratively, if the following conditions are met:
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ad72b5d7beccde..1f1d8ad89ae2b9 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3558,7 +3558,7 @@ transform::HoistRedundantVectorTransfersOp::applyToOne(
   // WARNING: This hoisting does not model parallelism and is generally
   // incorrect when used on distributed loops with memref semantics!
   // TODO: obsolete and should be retired.
-  linalg::hoistRedundantVectorTransfers(target);
+  linalg::hoistRedundantVectorTransfers(target, getVerifyNonZeroTrip());
   results.push_back(target);
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 92345e7695151c..6e1160a510a35b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -199,7 +199,8 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
   return true;
 }
 
-void mlir::linalg::hoistRedundantVectorTransfers(Operation *root) {
+void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
+                                                 bool verifyNonZeroTrip) {
   bool changed = true;
   while (changed) {
     changed = false;
@@ -213,39 +214,41 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root) {
     // a potentially zero trip count loop may cause a vector transfer to be
     // executed when it shouldn't be.
     llvm::DenseSet<LoopLikeOpInterface> definiteNonZeroTripCountLoops;
-    root->walk([&](LoopLikeOpInterface loopLike) {
-      std::optional<SmallVector<OpFoldResult>> lbs =
-          loopLike.getLoopLowerBounds();
-      std::optional<SmallVector<OpFoldResult>> ubs =
-          loopLike.getLoopUpperBounds();
-      // If loop bounds cannot be found, assume possibly zero trip count.
-      if (!lbs || !ubs) {
-        return;
-      }
-      // Otherwise, use ValueBounds to find the maximum lower bound and
-      // minimum upper bound. If the bounds are found, and maxLb is less
-      // than the minUb, then the loop will not have zero trip count.
-      for (auto [lb, ub] : llvm::zip_equal(lbs.value(), ubs.value())) {
-        FailureOr<int64_t> maxLb =
-            ValueBoundsConstraintSet::computeConstantBound(
-                presburger::BoundType::UB, /*var=*/lb,
-                /*stopCondition=*/nullptr, /*closedUB=*/true);
-        if (failed(maxLb)) {
-          return;
-        }
-        FailureOr<int64_t> minUb =
-            ValueBoundsConstraintSet::computeConstantBound(
-                presburger::BoundType::LB, /*var=*/ub,
-                /*stopCondition=*/nullptr);
-        if (failed(minUb)) {
+    if (verifyNonZeroTrip) {
+      root->walk([&](LoopLikeOpInterface loopLike) {
+        std::optional<SmallVector<OpFoldResult>> lbs =
+            loopLike.getLoopLowerBounds();
+        std::optional<SmallVector<OpFoldResult>> ubs =
+            loopLike.getLoopUpperBounds();
+        // If loop bounds cannot be found, assume possibly zero trip count.
+        if (!lbs || !ubs) {
           return;
         }
-        if (minUb.value() <= maxLb.value()) {
-          return;
+        // Otherwise, use ValueBounds to find the maximum lower bound and
+        // minimum upper bound. If the bounds are found, and maxLb is less
+        // than the minUb, then the loop will not have zero trip count.
+        for (auto [lb, ub] : llvm::zip_equal(lbs.value(), ubs.value())) {
+          FailureOr<int64_t> maxLb =
+              ValueBoundsConstraintSet::computeConstantBound(
+                  presburger::BoundType::UB, /*var=*/lb,
+                  /*stopCondition=*/nullptr, /*closedUB=*/true);
+          if (failed(maxLb)) {
+            return;
+          }
+          FailureOr<int64_t> minUb =
+              ValueBoundsConstraintSet::computeConstantBound(
+                  presburger::BoundType::LB, /*var=*/ub,
+                  /*stopCondition=*/nullptr);
+          if (failed(minUb)) {
+            return;
+          }
+          if (minUb.value() <= maxLb.value()) {
+            return;
+          }
+          definiteNonZeroTripCountLoops.insert(loopLike);
         }
-        definiteNonZeroTripCountLoops.insert(loopLike);
-      }
-    });
+      });
+    }
 
     root->walk([&](vector::TransferReadOp transferRead) {
       if (!isa<MemRefType>(transferRead.getShapedType()))
@@ -259,7 +262,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root) {
       if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop))
         return WalkResult::advance();
 
-      if (!definiteNonZeroTripCountLoops.contains(loop)) {
+      if (verifyNonZeroTrip && !definiteNonZeroTripCountLoops.contains(loop)) {
         LLVM_DEBUG(DBGS() << "Loop may have zero trip count: " << *loop
                           << "\n");
         return WalkResult::advance();
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index f1bd14d233bd59..e39e2d10426491 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -8,21 +8,21 @@
 //  CHECK-SAME:   %[[MEMREF4:[a-zA-Z0-9]*]]: memref<?x?xf32>,
 //  CHECK-SAME:   %[[MEMREF5:[a-zA-Z0-9]*]]: memref<?x?xf32>,
 //  CHECK-SAME:   %[[VAL:[a-zA-Z0-9]*]]: index,
+//  CHECK-SAME:   %[[LB:[a-zA-Z0-9]*]]: index,
+//  CHECK-SAME:   %[[UB:[a-zA-Z0-9]*]]: index,
 //  CHECK-SAME:   %[[STEP:[a-zA-Z0-9]*]]: index,
 //  CHECK-SAME:   %[[CMP:[a-zA-Z0-9]*]]: i1
 func.func @hoist_vector_transfer_pairs(
     %memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>, %memref2: memref<?x?xf32>,
     %memref3: memref<?x?xf32>, %memref4: memref<?x?xf32>, %memref5: memref<?x?xf32>,
-    %val: index, %step: index, %cmp: i1) {
-  %lb = arith.constant 0 : index
-  %ub = arith.constant 16 : index
+    %val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.0 : f32
 
 // CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
-// CHECK: scf.for %[[I:.*]] = {{.*}} to {{.*}} step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) {
+// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) {
 // CHECK:   vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<2xf32>
-// CHECK:   scf.for %[[J:.*]] = {{.*}} to {{.*}} step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) {
+// CHECK:   scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) {
 // CHECK:     vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<3xf32>
 // CHECK:     vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<4xf32>
 // CHECK:     "some_crippling_use"(%[[MEMREF4]]) : (memref<?x?xf32>) -> ()
@@ -92,15 +92,15 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-SAME:   %[[MEMREF2:[a-zA-Z0-9]*]]: memref<?x?xf32>,
 //  CHECK-SAME:   %[[MEMREF3:[a-zA-Z0-9]*]]: memref<?x?xf32>,
 //  CHECK-SAME:   %[[VAL:[a-zA-Z0-9]*]]: index,
+//  CHECK-SAME:   %[[LB:[a-zA-Z0-9]*]]: index,
+//  CHECK-SAME:   %[[UB:[a-zA-Z0-9]*]]: index,
 //  CHECK-SAME:   %[[STEP:[a-zA-Z0-9]*]]: index,
 //  CHECK-SAME:   %[[RANDOM:[a-zA-Z0-9]*]]: index,
 //  CHECK-SAME:   %[[CMP:[a-zA-Z0-9]*]]: i1
 func.func @hoist_vector_transfer_pairs_disjoint(
     %memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>,
-    %memref2: memref<?x?xf32>, %memref3: memref<?x?xf32>, %val: index,
+    %memref2: memref<?x?xf32>, %memref3: memref<?x?xf32>, %val: index, %lb : index, %ub : index,
     %step: index, %random_index : index, %cmp: i1) {
-  %lb = arith.constant 0 : index
-  %ub = arith.constant 16 : index
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c3 = arith.constant 3 : index
@@ -110,9 +110,9 @@ func.func @hoist_vector_transfer_pairs_disjoint(
 // CHECK: vector.transfer_read %[[MEMREF2]]{{.*}} : memref<?x?xf32>, vector<3xf32>
 // CHECK: vector.transfer_read %[[MEMREF3]]{{.*}} : memref<?x?xf32>, vector<4xf32>
 // CHECK: vector.transfer_read %[[MEMREF3]]{{.*}} : memref<?x?xf32>, vector<4xf32>
-// CHECK: scf.for %[[I:.*]] = {{.*}} to {{.*}} step %[[STEP]] iter_args({{.*}}) ->
+// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) ->
 //  CHECK-SAME: (vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
-// CHECK:   scf.for %[[J:.*]] = {{.*}} to {{.*}} step %[[STEP]] iter_args({{.*}}) ->
+// CHECK:   scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) ->
 //  CHECK-SAME: (vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
 // CHECK:     vector.transfer_read %[[MEMREF1]]{{.*}} : memref<?x?xf32>, vector<2xf32>
 // CHECK:     vector.transfer_read %[[MEMREF1]]{{.*}} : memref<?x?xf32>, vector<2xf32>
@@ -309,7 +309,7 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 // CHECK-LABEL:  func.func @no_hoisting_zero_trip_loop
-func.func @no_hoisting_zero_trip_loop(%arg0: memref<20xi32>, %arg1: memref<20xi32>, %lb: index, %ub: index) {
+func.func @no_hoisting_zero_trip_loop(%arg0: memref<20xi32>, %lb: index, %ub: index) {
   %c0_i32 = arith.constant 0 : i32
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -317,10 +317,10 @@ func.func @no_hoisting_zero_trip_loop(%arg0: memref<20xi32>, %arg1: memref<20xi3
 
   // CHECK:       scf.for {{.*}} {
   // CHECK-NEXT:    vector.transfer_read
-  // CHECK-NEXT:    vector.transfer_write
+  // CHECK-NEXT:    "prevent.dce"
   scf.for %arg2 = %lb to %ub step %c1 {
     %read = vector.transfer_read %arg0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
-    vector.transfer_write %read, %arg1[%c0] {in_bounds = [true]} : vector<4xi32>, memref<20xi32>
+    "prevent.dce"(%read) : (vector<4xi32>) ->()
   }
 
   // %lb_0 is in range [%lb, 8], and %ub_0 is in range [4, %ub].
@@ -330,10 +330,10 @@ func.func @no_hoisting_zero_trip_loop(%arg0: memref<20xi32>, %arg1: memref<20xi3
 
   // CHECK:       scf.for {{.*}} {
   // CHECK-NEXT:    vector.transfer_read
-  // CHECK-NEXT:    vector.transfer_write
+  // CHECK-NEXT:    "prevent.dce"
   scf.for %arg2 = %lb_0 to %ub_0 step %c1 {
     %read = vector.transfer_read %arg0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
-    vector.transfer_write %read, %arg1[%c0] {in_bounds = [true]} : vector<4xi32>, memref<20xi32>
+    "prevent.dce"(%read) : (vector<4xi32>) ->()
   }
 
   // %lb_1 is in range [%lb, 4], and %ub_1 is in range [8, %ub].
@@ -341,13 +341,12 @@ func.func @no_hoisting_zero_trip_loop(%arg0: memref<20xi32>, %arg1: memref<20xi3
   %lb_1 = affine.min affine_map<(d0) -> (d0, 4)>(%lb)
   %ub_1 = affine.max affine_map<(d0) -> (d0, 8)>(%ub)
 
-  // CHECK:    vector.transfer_read
+  // CHECK:       vector.transfer_read
   // CHECK:       scf.for {{.*}} {
   // CHECK-NEXT:    "prevent.dce"
   scf.for %arg2 = %lb_1 to %ub_1 step %c1 {
     %read = vector.transfer_read %arg0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
     "prevent.dce"(%read) : (vector<4xi32>) ->()
-    vector.transfer_write %read, %arg1[%c0] {in_bounds = [true]} : vector<4xi32>, memref<20xi32>
   }
   return
 }
@@ -356,7 +355,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["func.func"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.hoist_redundant_vector_transfers %0
+    transform.structured.hoist_redundant_vector_transfers %0 { verify_non_zero_trip }
       : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
@@ -492,7 +491,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK: #[[$MAP4:.+]] = affine_map<()[s0] -> (s0 + 4)>
 
 //   CHECK-LABEL: func.func @hoist_vector_transfer_pairs_disjoint_dynamic
-//    CHECK-SAME: (%[[BUFFER:.+]]: memref<?x?xf32>, %{{.+}}: index, %[[I0:.+]]: index)
+//    CHECK-SAME: (%[[BUFFER:.+]]: memref<?x?xf32>, %{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[I0:.+]]: index)
 
 //         CHECK:   %[[PLUS1:.+]] = affine.apply #[[$MAP1]]()[%[[I0]]]
 //         CHECK:   %[[PLUS4:.+]] = affine.apply #[[$MAP4]]()[%[[I0]]]
@@ -507,9 +506,7 @@ module attributes {transform.with_named_sequence} {
 //         CHECK:   vector.transfer_write %{{.+}}, %[[BUFFER]][%[[I0]], %[[I0]]]
 
 func.func @hoist_vector_transfer_pairs_disjoint_dynamic(
-    %buffer: memref<?x?xf32>, %step: index, %i0 : index) {
-  %lb = arith.constant 0 : index
-  %ub = arith.constant 16 : index
+    %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index) {
   %cst = arith.constant 0.0 : f32
   %i1 = affine.apply affine_map<(d0) -> (d0 + 1)>(%i0)
   %i2 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i0)
@@ -552,9 +549,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-COUNT-2:     vector.transfer_write
 
 func.func @hoist_vector_transfer_pairs_overlapping_dynamic(
-    %buffer: memref<?x?xf32>, %step: index, %i0 : index) {
-  %lb = arith.constant 0 : index
-  %ub = arith.constant 16 : index
+    %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index) {
   %cst = arith.constant 0.0 : f32
   %i1 = affine.apply affine_map<(d0) -> (d0 + 3)>(%i0)
 
@@ -594,9 +589,7 @@ module attributes {transform.with_named_sequence} {
 //         CHECK:   return
 
 func.func @hoist_vector_transfer_pairs_disjoint_dynamic(
-    %buffer: memref<?x?xf32>, %step: index, %i0 : index, %i1 : index) {
-  %lb = arith.constant 0 : index
-  %ub = arith.constant 16 : index
+    %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index, %i1 : index) {
   %cst = arith.constant 0.0 : f32
   %i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1)
   %i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1)
@@ -633,7 +626,7 @@ module attributes {transform.with_named_sequence} {
 // Test hoisting of vector.extract/vector.broadcast pairs
 
 // CHECK-LABEL:  func.func @hoist_vector_broadcasts
-//       CHECK-SAME: (%{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>) -> vector<3x4xf32> {
+//       CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>) -> vector<3x4xf32> {
 //       CHECK:        %[[EXTRACT:.+]] = vector.extract %[[VEC]][0] : vector<4xf32> from vector<3x4xf32>
 //       CHECK-NEXT:   %[[LOOP:.+]] = scf.for {{.*}} {
 //       CHECK-NEXT:     %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
@@ -642,9 +635,7 @@ module attributes {transform.with_named_sequence} {
 //       CHECK-NEXT:   %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
 //       CHECK-NEXT:   return %[[BCAST]] : vector<3x4xf32>
 
-func.func @hoist_vector_broadcasts(%step : index, %vec : vector<3x4xf32>) -> vector<3x4xf32> {
-  %lb = arith.constant 0 : index
-  %ub = arith.constant 16 : index
+func.func @hoist_vector_broadcasts(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>) -> vector<3x4xf32> {
   %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
     %extract = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
     %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
@@ -669,7 +660,7 @@ module attributes {transform.with_named_sequence} {
 // Test hoisting of vector.extract/vector.broadcast pairs with dynamic position
 
 // CHECK-LABEL:  func.func @hoist_vector_broadcasts
-//       CHECK-SAME: (%{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> {
+//       CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> {
 //       CHECK:        %[[EXTRACT:.+]] = vector.extract %[[VEC]][%[[POS]]] : vector<4xf32> from vector<3x4xf32>
 //       CHECK-NEXT:   %[[LOOP:.+]] = scf.for {{.*}} {
 //       CHECK-NEXT:     %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
@@ -678,9 +669,7 @@ module attributes {transform.with_named_sequence} {
 //       CHECK-NEXT:   %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
 //       CHECK-NEXT:   return %[[BCAST]] : vector<3x4xf32>
 
-func.func @hoist_vector_broadcasts_dynamic(%step : index, %vec : vector<3x4xf32>, %pos: index) -> vector<3x4xf32> {
-  %lb = arith.constant 0 : index
-  %ub = arith.constant 16 : index
+func.func @hoist_vector_broadcasts_dynamic(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>, %pos: index) -> vector<3x4xf32> {
   %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
     %extract = vector.extract %iarg[%pos] : vector<4xf32> from vector<3x4xf32>
     %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
@@ -705,7 +694,7 @@ module attributes {transform.with_named_sequence} {
 // Test hoisting of vector.extract/vector.broadcast pairs with multiple iter_args
 
 // CHECK-LABEL:  func.func @hoist_vector_broadcasts_multiple
-//       CHECK-SAME: (%{{.+}}: index, %[[VEC1:.+]]: vector<3x4xf32>,
+//       CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC1:.+]]: vector<3x4xf32>,
 //       CHECK-SAME:  %[[VEC2:.+]]: vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
 //       CHECK-DAG:     %[[EXTRACT1:.+]] = vector.extract %[[VEC1]][0] : vector<4xf32> from vector<3x4xf32>
 //       CHECK-DAG:     %[[EXTRACT2:.+]] = vector.extract %[[VEC2]][1] : vector<5xf32> from vector<3x5xf32>
@@ -718,9 +707,7 @@ module attributes {transform.with_named_sequence} {
 //       CHECK-DAG:     %[[BCAST2:.+]] = vector.broadcast %[[LOOP]]#1 : vector<5xf32> to vector<3x5xf32>
 //       CHECK-NEXT:    return %[[BCAST1]], %[[BCAST2]] : vector<3x4xf32>, vector<3x5xf32>
 
-func.func @hoist_vector_broadcasts_multiple(%step : index, %vec1 : vector<3x4xf32>, %vec2 : vector<3x5xf32>) ->  (vector<3x4xf32>, vector<3x5xf32>) {
-  %lb = arith.constant 0 : index
-  %ub = arith.constant 16 : index
+func.func @hoist_vector_broadcasts_multiple(%lb : index, %ub : index, %step : index, %vec1 : vector<3x4xf32>, %vec2 : vector<3x5xf32>) ->  (vector<3x4xf32>, vector<3x5xf32>) {
   %bcast_vec:2 = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec1, %iarg2 = %vec2) -> (vector<3x4xf32>, vector<3x5xf32>) {
     %extract1 = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
     %extract2 = vector.extract %iarg2[1] : vector<5xf32> from vector<3x5xf32>

>From 0ea1156d72e22ac8929b79d6dad02f344ec4742e Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 18 Oct 2024 09:55:09 -0400
Subject: [PATCH 3/3] address comments

Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
 .../Dialect/Linalg/Transforms/Hoisting.cpp    |  9 +--
 mlir/test/Dialect/Linalg/hoisting.mlir        | 76 ++++++++++++++-----
 2 files changed, 62 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 6e1160a510a35b..85bade0a444928 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -221,24 +221,23 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
         std::optional<SmallVector<OpFoldResult>> ubs =
             loopLike.getLoopUpperBounds();
         // If loop bounds cannot be found, assume possibly zero trip count.
-        if (!lbs || !ubs) {
+        if (!lbs || !ubs)
           return;
-        }
+
         // Otherwise, use ValueBounds to find the maximum lower bound and
         // minimum upper bound. If the bounds are found, and maxLb is less
         // than the minUb, then the loop will not have zero trip count.
         for (auto [lb, ub] : llvm::zip_equal(lbs.value(), ubs.value())) {
           FailureOr<int64_t> maxLb =
               ValueBoundsConstraintSet::computeConstantBound(
-                  presburger::BoundType::UB, /*var=*/lb,
+                  presburger::BoundType::UB, lb,
                   /*stopCondition=*/nullptr, /*closedUB=*/true);
           if (failed(maxLb)) {
             return;
           }
           FailureOr<int64_t> minUb =
               ValueBoundsConstraintSet::computeConstantBound(
-                  presburger::BoundType::LB, /*var=*/ub,
-                  /*stopCondition=*/nullptr);
+                  presburger::BoundType::LB, ub);
           if (failed(minUb)) {
             return;
           }
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index e39e2d10426491..3330ce70f42202 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -308,20 +308,40 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// CHECK-LABEL:  func.func @no_hoisting_zero_trip_loop
-func.func @no_hoisting_zero_trip_loop(%arg0: memref<20xi32>, %lb: index, %ub: index) {
+// CHECK-LABEL:  func.func @no_hoisting_unknown_bound_loop
+func.func @no_hoisting_unknown_bound_loop(%memref0: memref<20xi32>, %lb: index, %ub: index) {
   %c0_i32 = arith.constant 0 : i32
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  // %lb and %ub are unbounded, so do not hoist.
 
+  // %lb and %ub are unbounded, so do not hoist.
   // CHECK:       scf.for {{.*}} {
   // CHECK-NEXT:    vector.transfer_read
-  // CHECK-NEXT:    "prevent.dce"
+  // CHECK-NEXT:    "test.some_use"
   scf.for %arg2 = %lb to %ub step %c1 {
-    %read = vector.transfer_read %arg0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
-    "prevent.dce"(%read) : (vector<4xi32>) ->()
+    %read = vector.transfer_read %memref0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
+    "test.some_use"(%read) : (vector<4xi32>) ->()
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0 { verify_non_zero_trip }
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
   }
+}
+
+// -----
+
+// CHECK-LABEL:  func.func @no_hoisting_possibly_zero_trip_loop
+func.func @no_hoisting_possibly_zero_trip_loop(%memref0: memref<20xi32>, %lb: index, %ub: index) {
+  %c0_i32 = arith.constant 0 : i32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
 
   // %lb_0 is in range [%lb, 8], and %ub_0 is in range [4, %ub].
   // Since %lb_0 could be greater than %ub_0, do not hoist.
@@ -330,23 +350,43 @@ func.func @no_hoisting_zero_trip_loop(%arg0: memref<20xi32>, %lb: index, %ub: in
 
   // CHECK:       scf.for {{.*}} {
   // CHECK-NEXT:    vector.transfer_read
-  // CHECK-NEXT:    "prevent.dce"
+  // CHECK-NEXT:    "test.some_use"
   scf.for %arg2 = %lb_0 to %ub_0 step %c1 {
-    %read = vector.transfer_read %arg0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
-    "prevent.dce"(%read) : (vector<4xi32>) ->()
+    %read = vector.transfer_read %memref0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
+    "test.some_use"(%read) : (vector<4xi32>) ->()
   }
+  return
+}
 
-  // %lb_1 is in range [%lb, 4], and %ub_1 is in range [8, %ub].
-  // Since %lb_1 is guaranteed to be less than %ub_1, hoisting is possible.
-  %lb_1 = affine.min affine_map<(d0) -> (d0, 4)>(%lb)
-  %ub_1 = affine.max affine_map<(d0) -> (d0, 8)>(%ub)
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0 { verify_non_zero_trip }
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL:  func.func @hoisting_non_zero_trip_loop
+func.func @hoisting_non_zero_trip_loop(%memref0: memref<20xi32>, %lb: index, %ub: index) {
+  %c0_i32 = arith.constant 0 : i32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  // %lb_0 = min(%lb, 4) is at most 4, and %ub_0 = max(%ub, 8) is at least 8.
+  // Since %lb_0 is guaranteed to be less than %ub_0, hoisting is possible.
+  %lb_0 = affine.min affine_map<(d0) -> (d0, 4)>(%lb)
+  %ub_0 = affine.max affine_map<(d0) -> (d0, 8)>(%ub)
 
   // CHECK:       vector.transfer_read
   // CHECK:       scf.for {{.*}} {
-  // CHECK-NEXT:    "prevent.dce"
-  scf.for %arg2 = %lb_1 to %ub_1 step %c1 {
-    %read = vector.transfer_read %arg0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
-    "prevent.dce"(%read) : (vector<4xi32>) ->()
+  // CHECK-NEXT:    "test.some_use"
+  scf.for %arg2 = %lb_0 to %ub_0 step %c1 {
+    %read = vector.transfer_read %memref0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
+    "test.some_use"(%read) : (vector<4xi32>) ->()
   }
   return
 }
@@ -421,7 +461,7 @@ func.func @no_hoisting_collapse_shape_2(%vec: vector<1x12x1xi32>) {
     %collapse_shape = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x12x1xi32> into memref<12xi32>
     vector.transfer_write %vec, %alloca[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x12x1xi32>, memref<1x12x1xi32>
     %read = vector.transfer_read %collapse_shape[%c0], %c0_i32 {in_bounds = [true]} : memref<12xi32>, vector<12xi32>
-    "prevent.dce"(%read) : (vector<12xi32>) ->()
+    "test.some_use"(%read) : (vector<12xi32>) ->()
   }
   return
 }



More information about the Mlir-commits mailing list