[llvm] [LAA] Support different strides & non constant dep distances using SCEV. (PR #88039)

Michael Kruse via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 16 05:47:16 PDT 2024


================
@@ -2028,68 +2029,105 @@ MemoryDepChecker::Dependence::DepType MemoryDepChecker::isDependent(
   if (std::holds_alternative<Dependence::DepType>(Res))
     return std::get<Dependence::DepType>(Res);
 
-  const auto &[Dist, Stride, TypeByteSize, AIsWrite, BIsWrite] =
+  const auto &[Dist, StrideA, StrideB, TypeByteSize, AIsWrite, BIsWrite] =
       std::get<DepDistanceStrideAndSizeInfo>(Res);
   bool HasSameSize = TypeByteSize > 0;
 
+  uint64_t CommonStride = StrideA == StrideB ? StrideA : 0;
+  if (isa<SCEVCouldNotCompute>(Dist)) {
+    FoundNonConstantDistanceDependence |= CommonStride;
+    LLVM_DEBUG(dbgs() << "LAA: Dependence because of uncomputable distance.\n");
+    return Dependence::Unknown;
+  }
+
   ScalarEvolution &SE = *PSE.getSE();
   auto &DL = InnermostLoop->getHeader()->getModule()->getDataLayout();
-  if (!isa<SCEVCouldNotCompute>(Dist) && HasSameSize &&
+  if (HasSameSize && CommonStride &&
       isSafeDependenceDistance(DL, SE, *(PSE.getBackedgeTakenCount()), *Dist,
-                               Stride, TypeByteSize))
+                               CommonStride, TypeByteSize))
     return Dependence::NoDep;
 
   const SCEVConstant *C = dyn_cast<SCEVConstant>(Dist);
-  if (!C) {
-    LLVM_DEBUG(dbgs() << "LAA: Dependence because of non-constant distance\n");
-    FoundNonConstantDistanceDependence = true;
-    return Dependence::Unknown;
-  }
-
-  const APInt &Val = C->getAPInt();
-  int64_t Distance = Val.getSExtValue();
 
   // Attempt to prove strided accesses independent.
-  if (std::abs(Distance) > 0 && Stride > 1 && HasSameSize &&
-      areStridedAccessesIndependent(std::abs(Distance), Stride, TypeByteSize)) {
-    LLVM_DEBUG(dbgs() << "LAA: Strided accesses are independent\n");
-    return Dependence::NoDep;
+  if (C) {
+    const APInt &Val = C->getAPInt();
+    int64_t Distance = Val.getSExtValue();
+
+    if (std::abs(Distance) > 0 && CommonStride > 1 && HasSameSize &&
+        areStridedAccessesIndependent(std::abs(Distance), CommonStride,
+                                      TypeByteSize)) {
+      LLVM_DEBUG(dbgs() << "LAA: Strided accesses are independent\n");
+      return Dependence::NoDep;
+    }
+
+    // Write to the same location with the same size.
+    if (C->isZero()) {
+      if (HasSameSize)
+        return Dependence::Forward;
+      LLVM_DEBUG(
+          dbgs()
+          << "LAA: Zero dependence difference but different type sizes\n");
+      return Dependence::Unknown;
+    }
   }
 
   // Negative distances are not plausible dependencies.
-  if (Val.isNegative()) {
+  if (SE.isKnownNonPositive(Dist)) {
+    if (!SE.isKnownNegative(Dist) && !HasSameSize) {
+      LLVM_DEBUG(dbgs() << "LAA: possibly zero dependence difference but "
+                           "different type sizes\n");
+      return Dependence::Unknown;
+    }
+
     bool IsTrueDataDependence = (AIsWrite && !BIsWrite);
     // There is no need to update MaxSafeVectorWidthInBits after call to
     // couldPreventStoreLoadForward, even if it changed MinDepDistBytes,
     // since a forward dependency will allow vectorization using any width.
-    if (IsTrueDataDependence && EnableForwardingConflictDetection &&
-        (!HasSameSize || couldPreventStoreLoadForward(Val.abs().getZExtValue(),
-                                                      TypeByteSize))) {
-      LLVM_DEBUG(dbgs() << "LAA: Forward but may prevent st->ld forwarding\n");
-      return Dependence::ForwardButPreventsForwarding;
+    if (IsTrueDataDependence && EnableForwardingConflictDetection) {
+      if (!C) {
+        FoundNonConstantDistanceDependence |= CommonStride;
+        return Dependence::Unknown;
+      }
+      if (!HasSameSize ||
+          couldPreventStoreLoadForward(C->getAPInt().abs().getZExtValue(),
+                                       TypeByteSize)) {
+        LLVM_DEBUG(
+            dbgs() << "LAA: Forward but may prevent st->ld forwarding\n");
+        return Dependence::ForwardButPreventsForwarding;
+      }
     }
 
     LLVM_DEBUG(dbgs() << "LAA: Dependence is negative\n");
     return Dependence::Forward;
   }
 
-  // Write to the same location with the same size.
-  if (Val == 0) {
-    if (HasSameSize)
-      return Dependence::Forward;
-    LLVM_DEBUG(
-        dbgs() << "LAA: Zero dependence difference but different type sizes\n");
+  if (!SE.isKnownPositive(Dist)) {
+    FoundNonConstantDistanceDependence |= !C && CommonStride;
     return Dependence::Unknown;
   }
 
-  assert(Val.isStrictlyPositive() && "Expect a positive value");
-
   if (!HasSameSize) {
     LLVM_DEBUG(dbgs() << "LAA: ReadWrite-Write positive dependency with "
                          "different type sizes\n");
+    FoundNonConstantDistanceDependence |= !C && CommonStride;
     return Dependence::Unknown;
   }
 
+  if (!C) {
+    LLVM_DEBUG(dbgs() << "LAA: Dependence because of non-constant distance\n");
+    FoundNonConstantDistanceDependence |= CommonStride;
+    return Dependence::Unknown;
+  }
+
+  // The logic below currently only supports StrideA ==  StrideB, i.e. there's a
+  // common stride.
+  if (!CommonStride)
+    return Dependence::Unknown;
----------------
Meinersbur wrote:

Many of the checks above also test `CommonStride` (e.g. `CommonStride > 1`). Move them below this check?

https://github.com/llvm/llvm-project/pull/88039


More information about the llvm-commits mailing list