[Mlir-commits] [mlir] [mlir][vector] Enable transfer op hoisting with dynamic indices (PR #68500)

Lei Zhang llvmlistbot at llvm.org
Sun Oct 15 16:00:16 PDT 2023


https://github.com/antiagainst updated https://github.com/llvm/llvm-project/pull/68500

>From f141bc9b1f7070f92067d5e19b38db3c9c4d5e67 Mon Sep 17 00:00:00 2001
From: Lei Zhang <antiagainst at gmail.com>
Date: Sat, 7 Oct 2023 15:12:58 -0700
Subject: [PATCH 1/6] [mlir][vector] Enable transfer op hoisting with dynamic
 indices

Recent changes (https://github.com/llvm/llvm-project/pull/66930)
disabled hoisting of vector transfer ops through view-like
intermediate ops. The recommended approach is to fold subview ops
into transfer op indices before invoking hoisting. As a result,
transfer op indices now involve dynamic values, where previously
they were static constants on top of subview ops. Hoisting
therefore no longer kicks in, which breaks downstream users.

To fix it, this commit enables hoisting transfer ops with dynamic
indices by using `ValueBoundsConstraintSet` to prove that ranges are
disjoint in `isDisjointTransferIndices`. Because that utility is
used in many places, including op folders, we introduce a flag for
it and set it to true only for "heavy" transforms: hoisting and
load-store forwarding.
---
 .../Affine/IR/ValueBoundsOpInterfaceImpl.h    |  12 +-
 .../mlir/Dialect/Vector/IR/VectorOps.h        |  19 ++-
 .../mlir/Interfaces/ValueBoundsOpInterface.h  |  10 ++
 .../Affine/IR/ValueBoundsOpInterfaceImpl.cpp  |   6 +-
 .../Dialect/Linalg/Transforms/Hoisting.cpp    |  12 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  66 +++++++--
 .../Transforms/VectorTransferOpTransforms.cpp |   6 +-
 .../lib/Interfaces/ValueBoundsOpInterface.cpp |  27 ++--
 mlir/test/Dialect/Linalg/hoisting.mlir        | 128 ++++++++++++++++++
 .../Dialect/Vector/vector-transferop-opt.mlir | 103 ++++++++++++++
 .../Dialect/Affine/TestReifyValueBounds.cpp   |  30 ++--
 11 files changed, 362 insertions(+), 57 deletions(-)
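
Note for readers: the disjointness rule described in the commit message
can be sketched in isolation as below. This is a simplified model, not
the MLIR implementation. `Index`, `constantDelta`, and `isDisjoint` are
hypothetical names introduced only for illustration; `constantDelta`
stands in for a constant-delta query and assumes an index is a dynamic
base plus a constant offset.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <optional>
#include <vector>

// Hypothetical stand-in for an index operand: `base` identifies a dynamic
// value, and `offset` is a constant added to it.
struct Index {
  int base;
  int64_t offset;
};

// Stand-in for a constant-delta query: it succeeds only when both indices
// share the same dynamic base, so their difference is a known constant.
std::optional<int64_t> constantDelta(const Index &a, const Index &b) {
  if (a.base != b.base)
    return std::nullopt;
  return a.offset - b.offset;
}

// Returns true only if the two accesses are provably disjoint.
// `vectorShape` holds the sizes of the trailing (vector) dimensions.
bool isDisjoint(const std::vector<Index> &a, const std::vector<Index> &b,
                const std::vector<int64_t> &vectorShape) {
  assert(a.size() == b.size() && vectorShape.size() <= a.size());
  std::size_t rankOffset = a.size() - vectorShape.size();
  for (std::size_t i = 0; i < a.size(); ++i) {
    std::optional<int64_t> delta = constantDelta(a[i], b[i]);
    if (!delta)
      continue; // Nothing can be proven for this dimension.
    if (i < rankOffset) {
      // Leading dimension: any nonzero delta selects a different slice.
      if (*delta != 0)
        return true;
    } else if (std::abs(*delta) >= vectorShape[i - rankOffset]) {
      // Trailing dimension: the accessed intervals do not overlap.
      return true;
    }
  }
  return false;
}
```

With a 1-D vector of size 4 on a 2-D buffer, indices `[i, i]` and
`[i + 1, i]` are disjoint through the leading dimension, `[i, i]` and
`[i, i + 4]` are disjoint through the trailing one, and `[i, i]` versus
`[i, i + 3]` cannot be proven disjoint. These mirror the cases in the
tests added by this patch.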

diff --git a/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h
index 5d4774861bdfd37..7c11c13380fadf3 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h
@@ -18,16 +18,18 @@ class Value;
 namespace affine {
 void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry);
 
-/// Compute whether the given values are equal. Return "failure" if equality
-/// could not be determined. `value1`/`value2` must be index-typed.
+/// Compute a constant delta between the given two values. Return "failure" if
+/// it cannot be determined. `value1`/`value2` must be index-typed.
 ///
-/// This function is similar to `ValueBoundsConstraintSet::areEqual`. To work
-/// around limitations in `FlatLinearConstraints`, this function fully composes
+/// This function is similar to
+/// `ValueBoundsConstraintSet::computeConstantDelta`. To work around
+/// limitations in `FlatLinearConstraints`, this function fully composes
 /// `value1` and `value2` (if they are the result of affine.apply ops) before
 /// populating the constraint set. The folding/composing logic can see
 /// opportunities for simplifications that the constraint set implementation
 /// cannot see.
-FailureOr<bool> fullyComposeAndCheckIfEqual(Value value1, Value value2);
+FailureOr<int64_t> fullyComposeAndComputeConstantDelta(Value value1,
+                                                       Value value2);
 } // namespace affine
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index fc0c80036ff79ad..9ab20e20d975429 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -105,16 +105,23 @@ bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read);
 /// op.
 bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite);
 
-/// Same behavior as `isDisjointTransferSet` but doesn't require the operations
-/// to have the same tensor/memref. This allows comparing operations accessing
-/// different tensors.
+/// Return true if we can prove that the transfer operations access disjoint
+/// memory, without requiring the accessed tensor/memref to be the same.
+///
+/// If `testDynamicValueUsingBounds` is true, tries to test dynamic values
+/// via ValueBoundsOpInterface.
 bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
-                               VectorTransferOpInterface transferB);
+                               VectorTransferOpInterface transferB,
+                               bool testDynamicValueUsingBounds = false);
 
 /// Return true if we can prove that the transfer operations access disjoint
-/// memory.
+/// memory, requiring the operations to access the same tensor/memref.
+///
+/// If `testDynamicValueUsingBounds` is true, tries to test dynamic values
+/// via ValueBoundsOpInterface.
 bool isDisjointTransferSet(VectorTransferOpInterface transferA,
-                           VectorTransferOpInterface transferB);
+                           VectorTransferOpInterface transferB,
+                           bool testDynamicValueUsingBounds = false);
 
 /// Return the result value of reducing two scalar/vector values with the
 /// corresponding arith operation.
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 2687d79aec68ebb..8f11c563e0cbd91 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -176,6 +176,16 @@ class ValueBoundsConstraintSet {
       presburger::BoundType type, AffineMap map, ValueDimList mapOperands,
       StopConditionFn stopCondition = nullptr, bool closedUB = false);
 
+  /// Compute a constant delta between the given two values. Return "failure"
+  /// if a constant delta could not be determined.
+  ///
+  /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
+  /// index-typed.
+  static FailureOr<int64_t>
+  computeConstantDelta(Value value1, Value value2,
+                       std::optional<int64_t> dim1 = std::nullopt,
+                       std::optional<int64_t> dim2 = std::nullopt);
+
   /// Compute whether the given values/dimensions are equal. Return "failure" if
   /// equality could not be determined.
   ///
diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
index d47c8eb8ccb4272..e705d16e08ffcd3 100644
--- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -103,8 +103,8 @@ void mlir::affine::registerValueBoundsOpInterfaceExternalModels(
   });
 }
 
-FailureOr<bool> mlir::affine::fullyComposeAndCheckIfEqual(Value value1,
-                                                          Value value2) {
+FailureOr<int64_t>
+mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) {
   assert(value1.getType().isIndex() && "expected index type");
   assert(value2.getType().isIndex() && "expected index type");
 
@@ -127,5 +127,5 @@ FailureOr<bool> mlir::affine::fullyComposeAndCheckIfEqual(Value value1,
       presburger::BoundType::EQ, map, valueDims);
   if (failed(bound))
     return failure();
-  return *bound == 0;
+  return *bound;
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 221bec713b38aa3..cbb2c507de69f9e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -173,16 +173,16 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
         if (auto transferWriteUse =
                 dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
           if (!vector::isDisjointTransferSet(
-                  cast<VectorTransferOpInterface>(transferWrite.getOperation()),
-                  cast<VectorTransferOpInterface>(
-                      transferWriteUse.getOperation())))
+                  cast<VectorTransferOpInterface>(*transferWrite),
+                  cast<VectorTransferOpInterface>(*transferWriteUse),
+                  /*testDynamicValueUsingBounds=*/true))
             return WalkResult::advance();
         } else if (auto transferReadUse =
                        dyn_cast<vector::TransferReadOp>(use.getOwner())) {
           if (!vector::isDisjointTransferSet(
-                  cast<VectorTransferOpInterface>(transferWrite.getOperation()),
-                  cast<VectorTransferOpInterface>(
-                      transferReadUse.getOperation())))
+                  cast<VectorTransferOpInterface>(*transferWrite),
+                  cast<VectorTransferOpInterface>(*transferReadUse),
+                  /*testDynamicValueUsingBounds=*/true))
             return WalkResult::advance();
         } else {
           // Unknown use, we cannot prove that it doesn't alias with the
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 044b6cc07d3d629..0f66d327b26863b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 
+#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -30,6 +31,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
@@ -168,39 +170,77 @@ bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
 }
 
 bool mlir::vector::isDisjointTransferIndices(
-    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB) {
+    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
+    bool testDynamicValueUsingBounds) {
   // For simplicity only look at transfer of same type.
   if (transferA.getVectorType() != transferB.getVectorType())
     return false;
   unsigned rankOffset = transferA.getLeadingShapedRank();
   for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
-    auto indexA = getConstantIntValue(transferA.indices()[i]);
-    auto indexB = getConstantIntValue(transferB.indices()[i]);
-    // If any of the indices are dynamic we cannot prove anything.
-    if (!indexA.has_value() || !indexB.has_value())
-      continue;
+    Value indexA = transferA.indices()[i];
+    Value indexB = transferB.indices()[i];
+    std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
+    std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
 
     if (i < rankOffset) {
       // For leading dimensions, if we can prove that index are different we
       // know we are accessing disjoint slices.
-      if (*indexA != *indexB)
-        return true;
+      if (cstIndexA.has_value() && cstIndexB.has_value()) {
+        if (*cstIndexA != *cstIndexB)
+          return true;
+        continue;
+      }
+      if (testDynamicValueUsingBounds) {
+        // First try to see if we can fully compose and simplify the affine
+        // expression as a fast track.
+        FailureOr<int64_t> delta =
+            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
+        if (succeeded(delta) && *delta != 0)
+          return true;
+
+        FailureOr<bool> testEqual =
+            ValueBoundsConstraintSet::areEqual(indexA, indexB);
+        if (succeeded(testEqual) && !testEqual.value())
+          return true;
+      }
     } else {
       // For this dimension, we slice a part of the memref we need to make sure
       // the intervals accessed don't overlap.
-      int64_t distance = std::abs(*indexA - *indexB);
-      if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
-        return true;
+      int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
+      if (cstIndexA.has_value() && cstIndexB.has_value()) {
+        int64_t distance = std::abs(*cstIndexA - *cstIndexB);
+        if (distance >= vectorDim)
+          return true;
+        continue;
+      }
+      if (testDynamicValueUsingBounds) {
+        // First try to see if we can fully compose and simplify the affine
+        // expression as a fast track.
+        FailureOr<int64_t> delta =
+            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
+        if (succeeded(delta) && std::abs(*delta) >= vectorDim)
+          return true;
+
+        FailureOr<int64_t> computeDelta =
+            ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
+        if (succeeded(computeDelta)) {
+          int64_t delta = std::abs(computeDelta.value());
+          if (delta >= vectorDim)
+            return true;
+        }
+      }
     }
   }
   return false;
 }
 
 bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
-                                         VectorTransferOpInterface transferB) {
+                                         VectorTransferOpInterface transferB,
+                                         bool testDynamicValueUsingBounds) {
   if (transferA.source() != transferB.source())
     return false;
-  return isDisjointTransferIndices(transferA, transferB);
+  return isDisjointTransferIndices(transferA, transferB,
+                                   testDynamicValueUsingBounds);
 }
 
 // Helper to iterate over n-D vector slice elements. Calculate the next
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 603b88f11c8e007..a5f1b28152b9bde 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -142,7 +142,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
       // Don't need to consider disjoint accesses.
       if (vector::isDisjointTransferSet(
               cast<VectorTransferOpInterface>(write.getOperation()),
-              cast<VectorTransferOpInterface>(transferOp.getOperation())))
+              cast<VectorTransferOpInterface>(transferOp.getOperation()),
+              /*testDynamicValueUsingBounds=*/true))
         continue;
     }
     blockingAccesses.push_back(user);
@@ -217,7 +218,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
       // the write.
       if (vector::isDisjointTransferSet(
               cast<VectorTransferOpInterface>(write.getOperation()),
-              cast<VectorTransferOpInterface>(read.getOperation())))
+              cast<VectorTransferOpInterface>(read.getOperation()),
+              /*testDynamicValueUsingBounds=*/true))
         continue;
       if (write.getSource() == read.getSource() &&
           dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index c00ee0315a9639a..ff941115219f68b 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -484,25 +484,32 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
   return failure();
 }
 
-FailureOr<bool>
-ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
-                                   std::optional<int64_t> dim1,
-                                   std::optional<int64_t> dim2) {
+FailureOr<int64_t>
+ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
+                                               std::optional<int64_t> dim1,
+                                               std::optional<int64_t> dim2) {
 #ifndef NDEBUG
   assertValidValueDim(value1, dim1);
   assertValidValueDim(value2, dim2);
 #endif // NDEBUG
 
-  // Subtract the two values/dimensions from each other. If the result is 0,
-  // both are equal.
   Builder b(value1.getContext());
   AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                                  b.getAffineDimExpr(0) - b.getAffineDimExpr(1));
-  FailureOr<int64_t> bound = computeConstantBound(
-      presburger::BoundType::EQ, map, {{value1, dim1}, {value2, dim2}});
-  if (failed(bound))
+  return computeConstantBound(presburger::BoundType::EQ, map,
+                              {{value1, dim1}, {value2, dim2}});
+}
+
+FailureOr<bool>
+ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
+                                   std::optional<int64_t> dim1,
+                                   std::optional<int64_t> dim2) {
+  // Subtract the two values/dimensions from each other. If the result is 0,
+  // both are equal.
+  FailureOr<int64_t> delta = computeConstantDelta(value1, value2, dim1, dim2);
+  if (failed(delta))
     return failure();
-  return *bound == 0;
+  return *delta == 0;
 }
 
 ValueBoundsConstraintSet::BoundBuilder &
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 7d0c3648c344b1d..9aaece6a201f799 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -872,3 +872,131 @@ transform.sequence failures(propagate) {
   transform.structured.hoist_redundant_vector_transfers %0
     : (!transform.any_op) -> !transform.any_op
 }
+
+// -----
+
+// Test that we can hoist out 1-D read-write pairs whose indices are dynamic values.
+
+// CHECK: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 + 1)>
+// CHECK: #[[$MAP4:.+]] = affine_map<()[s0] -> (s0 + 4)>
+
+//   CHECK-LABEL: func.func @hoist_vector_transfer_pairs_disjoint_dynamic
+//    CHECK-SAME: (%[[BUFFER:.+]]: memref<?x?xf32>, %{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[I0:.+]]: index)
+
+//         CHECK:   %[[PLUS1:.+]] = affine.apply #[[$MAP1]]()[%[[I0]]]
+//         CHECK:   %[[PLUS4:.+]] = affine.apply #[[$MAP4]]()[%[[I0]]]
+//         CHECK:   %2 = vector.transfer_read %[[BUFFER]][%[[I0]], %[[I0]]]
+//         CHECK:   %3 = vector.transfer_read %[[BUFFER]][%[[PLUS1]], %[[I0]]]
+//         CHECK:   %4 = vector.transfer_read %[[BUFFER]][%[[PLUS1]], %[[PLUS4]]]
+// CHECK-COUNT-2:   scf.for %{{.+}} = {{.+}} -> (vector<4xf32>, vector<4xf32>, vector<4xf32>)
+// CHECK-COUNT-3:     "some_use"
+// CHECK-COUNT-2:   scf.yield {{.+}} : vector<4xf32>, vector<4xf32>, vector<4xf32>
+//         CHECK:   vector.transfer_write %{{.+}}, %[[BUFFER]][%[[PLUS1]], %[[PLUS4]]]
+//         CHECK:   vector.transfer_write %{{.+}}, %[[BUFFER]][%[[PLUS1]], %[[I0]]]
+//         CHECK:   vector.transfer_write %{{.+}}, %[[BUFFER]][%[[I0]], %[[I0]]]
+
+func.func @hoist_vector_transfer_pairs_disjoint_dynamic(
+    %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index) {
+  %cst = arith.constant 0.0 : f32
+  %i1 = affine.apply affine_map<(d0) -> (d0 + 1)>(%i0)
+  %i2 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i0)
+
+  scf.for %i = %lb to %ub step %step {
+    scf.for %j = %lb to %ub step %step {
+      %r0 = vector.transfer_read %buffer[%i0, %i0], %cst: memref<?x?xf32>, vector<4xf32>
+      // Disjoint leading dim
+      %r1 = vector.transfer_read %buffer[%i1, %i0], %cst: memref<?x?xf32>, vector<4xf32>
+      // Non-overlap trailing dim
+      %r2 = vector.transfer_read %buffer[%i1, %i2], %cst: memref<?x?xf32>, vector<4xf32>
+      %u0 = "some_use"(%r0) : (vector<4xf32>) -> vector<4xf32>
+      %u1 = "some_use"(%r1) : (vector<4xf32>) -> vector<4xf32>
+      %u2 = "some_use"(%r2) : (vector<4xf32>) -> vector<4xf32>
+      vector.transfer_write %u0, %buffer[%i0, %i0] : vector<4xf32>, memref<?x?xf32>
+      vector.transfer_write %u1, %buffer[%i1, %i0] : vector<4xf32>, memref<?x?xf32>
+      vector.transfer_write %u2, %buffer[%i1, %i2] : vector<4xf32>, memref<?x?xf32>
+    }
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+  transform.structured.hoist_redundant_vector_transfers %0
+    : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+// Test that we cannot hoist out read-write pairs whose indices are overlapping.
+
+//   CHECK-LABEL: func.func @hoist_vector_transfer_pairs_overlapping_dynamic
+// CHECK-COUNT-2:   scf.for
+// CHECK-COUNT-2:     vector.transfer_read
+// CHECK-COUNT-2:     vector.transfer_write
+
+func.func @hoist_vector_transfer_pairs_overlapping_dynamic(
+    %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index) {
+  %cst = arith.constant 0.0 : f32
+  %i1 = affine.apply affine_map<(d0) -> (d0 + 3)>(%i0)
+
+  scf.for %i = %lb to %ub step %step {
+    scf.for %j = %lb to %ub step %step {
+      %r0 = vector.transfer_read %buffer[%i0, %i0], %cst: memref<?x?xf32>, vector<4xf32>
+      // Overlapping range with the above
+      %r1 = vector.transfer_read %buffer[%i0, %i1], %cst: memref<?x?xf32>, vector<4xf32>
+      %u0 = "some_use"(%r0) : (vector<4xf32>) -> vector<4xf32>
+      %u1 = "some_use"(%r1) : (vector<4xf32>) -> vector<4xf32>
+      vector.transfer_write %u0, %buffer[%i0, %i0] : vector<4xf32>, memref<?x?xf32>
+      vector.transfer_write %u1, %buffer[%i0, %i1] : vector<4xf32>, memref<?x?xf32>
+    }
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+  transform.structured.hoist_redundant_vector_transfers %0
+    : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+// Test that we can hoist out 2-D read-write pairs whose indices are dynamic values.
+
+//   CHECK-LABEL: func.func @hoist_vector_transfer_pairs_disjoint_dynamic
+// CHECK-COUNT-2:   vector.transfer_read
+// CHECK-COUNT-2:   %{{.+}}:2 = scf.for {{.+}} -> (vector<16x8xf32>, vector<16x8xf32>)
+// CHECK-COUNT-2:   scf.yield {{.+}} : vector<16x8xf32>, vector<16x8xf32>
+// CHECK-COUNT-2:   vector.transfer_write
+//         CHECK:   return
+
+func.func @hoist_vector_transfer_pairs_disjoint_dynamic(
+    %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index, %i1 : index) {
+  %cst = arith.constant 0.0 : f32
+  %i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1)
+  %i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1)
+
+  scf.for %i = %lb to %ub step %step {
+    scf.for %j = %lb to %ub step %step {
+      %r0 = vector.transfer_read %buffer[%i0, %i2], %cst: memref<?x?xf32>, vector<16x8xf32>
+      %r1 = vector.transfer_read %buffer[%i0, %i3], %cst: memref<?x?xf32>, vector<16x8xf32>
+      %u0 = "some_use"(%r0) : (vector<16x8xf32>) -> vector<16x8xf32>
+      %u1 = "some_use"(%r1) : (vector<16x8xf32>) -> vector<16x8xf32>
+      vector.transfer_write %u0, %buffer[%i0, %i2] : vector<16x8xf32>, memref<?x?xf32>
+      vector.transfer_write %u1, %buffer[%i0, %i3] : vector<16x8xf32>, memref<?x?xf32>
+    }
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+  transform.structured.hoist_redundant_vector_transfers %0
+    : (!transform.any_op) -> !transform.any_op
+}
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index f43367ab4aeba7d..d0daa53667fab46 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -256,3 +256,106 @@ func.func @collapse_shape(%in_0: memref<1x20x1xi32>, %vec: vector<4xi32>) {
   }
   return
 }
+
+// CHECK-LABEL: func @forward_dead_store_dynamic_same_index
+//   CHECK-NOT:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_store_dynamic_same_index(
+    %buffer : memref<?x?xf32>, %v0 : vector<4xf32>, %v1 : vector<4xf32>, %i : index) {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  vector.transfer_write %v0, %buffer[%i, %i] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%i, %i], %cf0 {in_bounds = [true]} : memref<?x?xf32>, vector<4xf32>
+  %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0) -> (vector<4xf32>) {
+    %1 = arith.addf %acc, %acc : vector<4xf32>
+    scf.yield %1 : vector<4xf32>
+  }
+  vector.transfer_write %x, %buffer[%i, %i] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  return
+}
+
+//   CHECK-LABEL: func @dont_forward_dead_store_dynamic_overlap
+// CHECK-COUNT-2:   vector.transfer_write
+//         CHECK:   vector.transfer_read
+//         CHECK:   scf.for
+//         CHECK:   }
+//         CHECK:   vector.transfer_write
+//         CHECK:   return
+func.func @dont_forward_dead_store_dynamic_overlap(
+    %buffer : memref<?x?xf32>, %v0 : vector<4xf32>, %v1 : vector<4xf32>, %i0 : index) {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %i1 = affine.apply affine_map<(d0) -> (d0 + 3)>(%i0)
+  vector.transfer_write %v0, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  // The following transfer op writes to an overlapping range so we cannot forward.
+  vector.transfer_write %v0, %buffer[%i0, %i1] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%i0, %i0], %cf0 {in_bounds = [true]} : memref<?x?xf32>, vector<4xf32>
+  %x = scf.for %iv = %c0 to %c4 step %c1 iter_args(%acc = %0) -> (vector<4xf32>) {
+    %1 = arith.addf %acc, %acc : vector<4xf32>
+    scf.yield %1 : vector<4xf32>
+  }
+  vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  return
+}
+
+// CHECK-LABEL: func @forward_dead_store_dynamic_non_overlap_leading_dim
+//       CHECK:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_store_dynamic_non_overlap_leading_dim(
+    %buffer : memref<?x?xf32>, %v0 : vector<4xf32>, %v1 : vector<4xf32>, %i0 : index) {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %i1 = affine.apply affine_map<(d0) -> (d0 + 1)>(%i0)
+  vector.transfer_write %v0, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  // The following transfer op writes to a non-overlapping range so we can forward.
+  vector.transfer_write %v0, %buffer[%i1, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%i0, %i0], %cf0 {in_bounds = [true]} : memref<?x?xf32>, vector<4xf32>
+  %x = scf.for %iv = %c0 to %c4 step %c1 iter_args(%acc = %0) -> (vector<4xf32>) {
+    %1 = arith.addf %acc, %acc : vector<4xf32>
+    scf.yield %1 : vector<4xf32>
+  }
+  vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  return
+}
+
+// CHECK-LABEL: func @forward_dead_store_dynamic_non_overlap_trailing_dim
+//       CHECK:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
+    %buffer : memref<?x?xf32>, %v0 : vector<4xf32>, %v1 : vector<4xf32>, %i0 : index) {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %i1 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i0)
+  vector.transfer_write %v0, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  // The following transfer op writes to a non-overlapping range so we can forward.
+  vector.transfer_write %v0, %buffer[%i0, %i1] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%i0, %i0], %cf0 {in_bounds = [true]} : memref<?x?xf32>, vector<4xf32>
+  %x = scf.for %iv = %c0 to %c4 step %c1 iter_args(%acc = %0) -> (vector<4xf32>) {
+    %1 = arith.addf %acc, %acc : vector<4xf32>
+    scf.yield %1 : vector<4xf32>
+  }
+  vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  return
+}
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 6e3c3dff759a2ed..2f1631cbdb02e01 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -187,20 +187,26 @@ static LogicalResult testEquality(func::FuncOp funcOp) {
         op->emitOpError("invalid op");
         return WalkResult::skip();
       }
-      FailureOr<bool> equal = failure();
       if (op->hasAttr("compose")) {
-        equal = affine::fullyComposeAndCheckIfEqual(op->getOperand(0),
-                                                    op->getOperand(1));
-      } else {
-        equal = ValueBoundsConstraintSet::areEqual(op->getOperand(0),
-                                                   op->getOperand(1));
-      }
-      if (failed(equal)) {
-        op->emitError("could not determine equality");
-      } else if (*equal) {
-        op->emitRemark("equal");
+        FailureOr<int64_t> equal = affine::fullyComposeAndComputeConstantDelta(
+            op->getOperand(0), op->getOperand(1));
+        if (failed(equal)) {
+          op->emitError("could not determine equality");
+        } else if (*equal == 0) {
+          op->emitRemark("equal");
+        } else {
+          op->emitRemark("different");
+        }
       } else {
-        op->emitRemark("different");
+        FailureOr<bool> equal = ValueBoundsConstraintSet::areEqual(
+            op->getOperand(0), op->getOperand(1));
+        if (failed(equal)) {
+          op->emitError("could not determine equality");
+        } else if (*equal) {
+          op->emitRemark("equal");
+        } else {
+          op->emitRemark("different");
+        }
       }
     }
     return WalkResult::advance();

>From 76fa22ac2d894592e5cddb7630a15285bf849fdf Mon Sep 17 00:00:00 2001
From: Lei Zhang <antiagainst at gmail.com>
Date: Sun, 15 Oct 2023 11:19:41 -0700
Subject: [PATCH 2/6] Address comments

---
 mlir/test/Dialect/Vector/vector-transferop-opt.mlir | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index d0daa53667fab46..13957af014b89ed 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -271,6 +271,7 @@ func.func @forward_dead_store_dynamic_same_index(
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32
   vector.transfer_write %v0, %buffer[%i, %i] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  // The following transfer op reads/writes to the same address so that we can forward.
   %0 = vector.transfer_read %buffer[%i, %i], %cf0 {in_bounds = [true]} : memref<?x?xf32>, vector<4xf32>
   %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0) -> (vector<4xf32>) {
     %1 = arith.addf %acc, %acc : vector<4xf32>

>From 9f4f4c5aa44bce5bb442524a4be8748f4b950fb7 Mon Sep 17 00:00:00 2001
From: Lei Zhang <antiagainst at gmail.com>
Date: Sat, 7 Oct 2023 15:54:31 -0700
Subject: [PATCH 3/6] Add ValueBoundsOpInterface as dependency in BUILD

---
 mlir/lib/Dialect/Vector/IR/CMakeLists.txt         | 2 ++
 utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
index 9ec919423b3428f..70f3fa8c297d4bc 100644
--- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRVectorDialect
   MLIRVectorAttributesIncGen
 
   LINK_LIBS PUBLIC
+  MLIRAffineDialect
   MLIRArithDialect
   MLIRControlFlowInterfaces
   MLIRDataLayoutInterfaces
@@ -22,5 +23,6 @@ add_mlir_dialect_library(MLIRVectorDialect
   MLIRMemRefDialect
   MLIRSideEffectInterfaces
   MLIRTensorDialect
+  MLIRValueBoundsOpInterface
   MLIRVectorInterfaces
   )
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index de13e03807e821b..63f9cdafce88b90 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4422,6 +4422,7 @@ cc_library(
     ]),
     includes = ["include"],
     deps = [
+        ":AffineDialect",
         ":ArithDialect",
         ":ArithUtils",
         ":ControlFlowInterfaces",
@@ -4436,6 +4437,7 @@ cc_library(
         ":SideEffectInterfaces",
         ":Support",
         ":TensorDialect",
+        ":ValueBoundsOpInterface",
         ":VectorAttributesIncGen",
         ":VectorDialectIncGen",
         ":VectorInterfaces",

>From 478b11542ba71ee471bf78c3b2a4432765bdd879 Mon Sep 17 00:00:00 2001
From: Lei Zhang <antiagainst at gmail.com>
Date: Sun, 15 Oct 2023 15:58:44 -0700
Subject: [PATCH 4/6] Update tests

---
 mlir/test/Dialect/Linalg/hoisting.mlir | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 9aaece6a201f799..18a57e6ac53918e 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -968,10 +968,10 @@ transform.sequence failures(propagate) {
 // Test that we can hoist out 2-D read-write pairs whose indices are dynamic values.
 
 //   CHECK-LABEL: func.func @hoist_vector_transfer_pairs_disjoint_dynamic
-// CHECK-COUNT-2:   vector.transfer_read
-// CHECK-COUNT-2:   %{{.+}}:2 = scf.for {{.+}} -> (vector<16x8xf32>, vector<16x8xf32>)
-// CHECK-COUNT-2:   scf.yield {{.+}} : vector<16x8xf32>, vector<16x8xf32>
-// CHECK-COUNT-2:   vector.transfer_write
+// CHECK-COUNT-3:   vector.transfer_read
+// CHECK-COUNT-2:   %{{.+}}:3 = scf.for {{.+}} -> (vector<16x8xf32>, vector<16x8xf32>, vector<16x8xf32>)
+// CHECK-COUNT-2:   scf.yield {{.+}} : vector<16x8xf32>, vector<16x8xf32>, vector<16x8xf32>
+// CHECK-COUNT-3:   vector.transfer_write
 //         CHECK:   return
 
 func.func @hoist_vector_transfer_pairs_disjoint_dynamic(
@@ -979,15 +979,19 @@ func.func @hoist_vector_transfer_pairs_disjoint_dynamic(
   %cst = arith.constant 0.0 : f32
   %i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1)
   %i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1)
+  %i4 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 16)>(%i1)
 
   scf.for %i = %lb to %ub step %step {
     scf.for %j = %lb to %ub step %step {
       %r0 = vector.transfer_read %buffer[%i0, %i2], %cst: memref<?x?xf32>, vector<16x8xf32>
       %r1 = vector.transfer_read %buffer[%i0, %i3], %cst: memref<?x?xf32>, vector<16x8xf32>
+      %r2 = vector.transfer_read %buffer[%i0, %i4], %cst: memref<?x?xf32>, vector<16x8xf32>
       %u0 = "some_use"(%r0) : (vector<16x8xf32>) -> vector<16x8xf32>
       %u1 = "some_use"(%r1) : (vector<16x8xf32>) -> vector<16x8xf32>
-      vector.transfer_write %u0, %buffer[%i0, %i2] : vector<16x8xf32>, memref<?x?xf32>
+      %u2 = "some_use"(%r2) : (vector<16x8xf32>) -> vector<16x8xf32>
+      vector.transfer_write %u2, %buffer[%i0, %i4] : vector<16x8xf32>, memref<?x?xf32>
       vector.transfer_write %u1, %buffer[%i0, %i3] : vector<16x8xf32>, memref<?x?xf32>
+      vector.transfer_write %u0, %buffer[%i0, %i2] : vector<16x8xf32>, memref<?x?xf32>
     }
   }
   return

>From be9f1093b2e95f40c76d9fca4abee7fed2fe5e9d Mon Sep 17 00:00:00 2001
From: Lei Zhang <antiagainst at gmail.com>
Date: Sun, 15 Oct 2023 15:53:25 -0700
Subject: [PATCH 5/6] Fix delta computation

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 0f66d327b26863b..68a5cf209f2fb49 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -218,14 +218,13 @@ bool mlir::vector::isDisjointTransferIndices(
         // expression as a fast track.
         FailureOr<int64_t> delta =
             affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
-        if (succeeded(delta) && *delta >= vectorDim)
+        if (succeeded(delta) && std::abs(*delta) >= vectorDim)
           return true;
 
         FailureOr<int64_t> computeDelta =
             ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
         if (succeeded(computeDelta)) {
-          int64_t delta = std::abs(computeDelta.value());
-          if (delta >= vectorDim)
+          if (std::abs(computeDelta.value()) >= vectorDim)
             return true;
         }
       }
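
Note on the hunk above: the fix matters because the delta between two indices is signed, while disjointness only depends on how far apart the two start indices are. The following is a minimal standalone sketch of that reasoning, not the MLIR implementation. The helper names `isDisjointAlongDim` and `isDisjointSigned` are hypothetical, and plain integers stand in for the constant delta that the patch obtains from `fullyComposeAndComputeConstantDelta` or `computeConstantDelta`.

```cpp
#include <cstdint>
#include <cstdlib>

// Hypothetical helper (not part of MLIR): two accesses that each cover
// `vectorDim` contiguous elements along one dimension, starting at
// `indexA` and `indexB`, cannot overlap when the start indices are at
// least `vectorDim` apart, regardless of which index is larger.
bool isDisjointAlongDim(int64_t indexA, int64_t indexB, int64_t vectorDim) {
  int64_t delta = indexA - indexB;
  return std::abs(delta) >= vectorDim;
}

// The signed comparison used on the fast path before this patch, kept
// here only for comparison: it misses the case where indexA < indexB.
bool isDisjointSigned(int64_t indexA, int64_t indexB, int64_t vectorDim) {
  return indexA - indexB >= vectorDim;
}
```

With a vector width of 8 along the dimension, start indices 0 and 8 are disjoint in either argument order under the absolute-value check, whereas the signed check only reports disjointness when the larger index is passed first.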

>From 290299ed9a40daead660149457b286a8f038d346 Mon Sep 17 00:00:00 2001
From: Lei Zhang <antiagainst at gmail.com>
Date: Sun, 15 Oct 2023 15:59:59 -0700
Subject: [PATCH 6/6] Address comments

---
 .../include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h
index 7c11c13380fadf3..6e617ef40a53d7d 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h
@@ -19,7 +19,7 @@ namespace affine {
 void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry);
 
 /// Compute a constant delta of the given two values. Return "failure" if we
-/// cannot determine. / `value1`/`value2` must be index-typed.
+/// cannot determine a constant delta. `value1`/`value2` must be index-typed.
 ///
 /// This function is similar to
 /// `ValueBoundsConstraintSet::computeConstantDistance`. To work around


