[llvm] LAA: generalize strides over unequal type sizes (PR #108088)
David Sherwood via llvm-commits
llvm-commits at lists.llvm.org
Tue Dec 17 05:36:36 PST 2024
================
@@ -1972,30 +1969,68 @@ MemoryDepChecker::getDependenceDistanceStrideAndSize(
return MemoryDepChecker::Dependence::IndirectUnsafe;
}
- int64_t StrideAPtrInt = *StrideAPtr;
- int64_t StrideBPtrInt = *StrideBPtr;
- LLVM_DEBUG(dbgs() << "LAA: Src induction step: " << StrideAPtrInt
- << " Sink induction step: " << StrideBPtrInt << "\n");
+ LLVM_DEBUG(dbgs() << "LAA: Src induction step: " << *StrideAPtr
+ << " Sink induction step: " << *StrideBPtr << "\n");
+
+ // Note that store size is different from alloc size, which is dependent on
+ // store size. We use the former for checking illegal cases, and the latter
+ // for scaling strides.
+ TypeSize AStoreSz = DL.getTypeStoreSize(ATy),
+ BStoreSz = DL.getTypeStoreSize(BTy);
+
+ // When the distance is zero, we're reading/writing the same memory location:
+ // check that the store sizes are equal. Otherwise, fail with an unknown
+ // dependence for which we should not generate runtime checks.
+ if (Dist->isZero() && AStoreSz != BStoreSz)
+ return MemoryDepChecker::Dependence::Unknown;
+
+ // We can't get a uint64_t for the AllocSize if either of the store sizes
+ // is scalable.
+ if (AStoreSz.isScalable() || BStoreSz.isScalable())
+ return MemoryDepChecker::Dependence::Unknown;
+
+ // The TypeByteSize is used to scale Distance and VF. In these contexts, the
+ // only size that matters is the size of the Sink.
+ uint64_t ASz = alignTo(AStoreSz, DL.getABITypeAlign(ATy).value()),
+ TypeByteSize = alignTo(BStoreSz, DL.getABITypeAlign(BTy).value());
+
+ // We scale the strides by the alloc-type-sizes, so we can check that the
+ // common distance is equal when ASz != BSz.
+ int64_t StrideAScaled = *StrideAPtr * ASz;
+ int64_t StrideBScaled = *StrideBPtr * TypeByteSize;
+
// At least Src or Sink are loop invariant and the other is strided or
// invariant. We can generate a runtime check to disambiguate the accesses.
- if (!StrideAPtrInt || !StrideBPtrInt)
+ if (!StrideAScaled || !StrideBScaled)
return MemoryDepChecker::Dependence::Unknown;
// Both Src and Sink have a constant stride, check if they are in the same
// direction.
- if ((StrideAPtrInt > 0) != (StrideBPtrInt > 0)) {
+ if ((StrideAScaled > 0) != (StrideBScaled > 0)) {
LLVM_DEBUG(
dbgs() << "Pointer access with strides in different directions\n");
return MemoryDepChecker::Dependence::Unknown;
}
- uint64_t TypeByteSize = DL.getTypeAllocSize(ATy);
- bool HasSameSize =
- DL.getTypeStoreSizeInBits(ATy) == DL.getTypeStoreSizeInBits(BTy);
- if (!HasSameSize)
- TypeByteSize = 0;
- return DepDistanceStrideAndSizeInfo(Dist, std::abs(StrideAPtrInt),
- std::abs(StrideBPtrInt), TypeByteSize,
+ StrideAScaled = std::abs(StrideAScaled);
+ StrideBScaled = std::abs(StrideBScaled);
+
+ // MaxStride is the max of the scaled strides, as expected.
+ uint64_t MaxStride = std::max(StrideAScaled, StrideBScaled);
+
+ // CommonStride is set if both scaled strides are equal.
+ std::optional<uint64_t> CommonStride;
+ if (StrideAScaled == StrideBScaled)
+ CommonStride = StrideAScaled;
+
+ // TODO: Historically, we don't retry with runtime checks unless the unscaled
+ // strides are the same, but this doesn't make sense. Fix this once the
+ // condition for runtime checks in isDependent is fixed.
+ bool ShouldRetryWithRuntimeCheck =
+ std::abs(*StrideAPtr) == std::abs(*StrideBPtr);
+
+ return DepDistanceStrideAndSizeInfo(Dist, MaxStride, CommonStride,
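The stride arithmetic in this hunk can be sanity-checked outside LLVM with a minimal standalone sketch. This is not the LLVM implementation, and it rests on these assumptions:
- Plain integers stand in for `TypeSize` and `DataLayout`.
- The hypothetical `alignUp` helper stands in for `llvm::alignTo`.
- The zero-distance and scalable-size early exits are omitted.

It only shows how element strides become byte strides, and how `MaxStride`, `CommonStride` and `ShouldRetryWithRuntimeCheck` fall out of them.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <optional>

// Hypothetical stand-in for llvm::alignTo: rounds Size up to a multiple of
// Align (Align must be non-zero).
uint64_t alignUp(uint64_t Size, uint64_t Align) {
  return (Size + Align - 1) / Align * Align;
}

// Hypothetical result type holding the values the hunk derives from strides.
struct ScaledStrides {
  bool Unknown = false;                     // invariant access or mixed directions
  uint64_t MaxStride = 0;                   // max of |scaled strides|, in bytes
  std::optional<uint64_t> CommonStride;     // set only if the scaled strides match
  bool ShouldRetryWithRuntimeCheck = false; // compares the *unscaled* strides
};

// StrideA/StrideB are element strides; the store sizes and alignments are bytes.
ScaledStrides scaleStrides(int64_t StrideA, uint64_t AStoreSz, uint64_t AAlign,
                           int64_t StrideB, uint64_t BStoreSz, uint64_t BAlign) {
  ScaledStrides R;
  // Alloc size: store size rounded up to the ABI alignment.
  int64_t ASz = static_cast<int64_t>(alignUp(AStoreSz, AAlign));
  int64_t BSz = static_cast<int64_t>(alignUp(BStoreSz, BAlign));
  // Scaling by the alloc size turns element strides into byte strides.
  int64_t A = StrideA * ASz;
  int64_t B = StrideB * BSz;
  // Zero stride: loop-invariant access. Opposite signs: different directions.
  if (A == 0 || B == 0 || (A > 0) != (B > 0)) {
    R.Unknown = true;
    return R;
  }
  A = std::abs(A);
  B = std::abs(B);
  R.MaxStride = static_cast<uint64_t>(std::max(A, B));
  if (A == B)
    R.CommonStride = static_cast<uint64_t>(A);
  // Mirrors the TODO: the retry decision still looks at the unscaled strides.
  R.ShouldRetryWithRuntimeCheck = std::abs(StrideA) == std::abs(StrideB);
  return R;
}
```

For example, assume an i32 access with stride 2 and an i64 access with stride 1, with alloc sizes of 4 and 8 bytes. Both strides scale to 8 bytes, so `scaleStrides(2, 4, 4, 1, 8, 8)` sets `CommonStride` to 8 even though the unscaled strides differ. `ShouldRetryWithRuntimeCheck` stays false in that case, which is the inconsistency the TODO describes.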
----------------
david-arm wrote:
It looks like you could reduce the complexity of this patch with an initial NFC refactoring PR that changes the constructor of `DepDistanceStrideAndSizeInfo` to take the `MaxStride`, `CommonStride` and `ShouldRetryWithRuntimeCheck` parameters. Essentially, that moves some of the work out of `MemoryDepChecker::isDependent` and into the construction of the `DepDistanceStrideAndSizeInfo` object. That seems sensible, given that we could call `isDependent` many times on the same object.
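Here is a hedged sketch of what the constructor could look like after such an NFC refactor. It makes the following assumptions:
- `DepInfoSketch` and `distanceIsStrideMultiple` are hypothetical names.
- The field list is cut down; the real `DepDistanceStrideAndSizeInfo` carries more state.
- A plain integer stands in for the SCEV distance.
- The query function is illustrative only, not LAA's actual dependence test.

The point is only that the stride-derived values arrive precomputed, so each query reads them instead of recomputing them.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical, cut-down stand-in for DepDistanceStrideAndSizeInfo after the
// suggested refactor: stride-derived values are passed in already computed.
struct DepInfoSketch {
  int64_t Dist;                         // placeholder for the SCEV distance
  uint64_t MaxStride;                   // max |scaled stride|, in bytes
  std::optional<uint64_t> CommonStride; // set only if scaled strides are equal
  bool ShouldRetryWithRuntimeCheck;
  uint64_t TypeByteSize;

  DepInfoSketch(int64_t Dist, uint64_t MaxStride,
                std::optional<uint64_t> CommonStride,
                bool ShouldRetryWithRuntimeCheck, uint64_t TypeByteSize)
      : Dist(Dist), MaxStride(MaxStride), CommonStride(CommonStride),
        ShouldRetryWithRuntimeCheck(ShouldRetryWithRuntimeCheck),
        TypeByteSize(TypeByteSize) {}
};

// Hypothetical query (not LAA's real logic): is the distance a whole multiple
// of the common byte stride? It only reads fields, so repeated calls on the
// same object do no stride arithmetic.
bool distanceIsStrideMultiple(const DepInfoSketch &Info) {
  if (!Info.CommonStride || *Info.CommonStride == 0)
    return false;
  return Info.Dist % static_cast<int64_t>(*Info.CommonStride) == 0;
}
```

With this shape, the scaling logic would live entirely in `getDependenceDistanceStrideAndSize`. The functional change in this PR would then be confined to how those three values are computed.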
https://github.com/llvm/llvm-project/pull/108088