[llvm] [AMDGPU][Attributor] Check the validity of a dependent AA before using its value (PR #114165)

via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 29 19:10:21 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-amdgpu

Author: Shilei Tian (shiltian)

<details>
<summary>Changes</summary>

Even though the Attributor framework can invalidate all dependent AAs after
the current iteration, a dependent AA can still use the worst state of an AA it
depends on if it doesn't first check whether that AA's state is valid.

---
Full diff: https://github.com/llvm/llvm-project/pull/114165.diff


1 File Affected:

- (modified) llvm/lib/Target/AMDGPU/AMDGPUAttributor.cpp (+12-8) 


``````````diff
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUAttributor.cpp b/llvm/lib/Target/AMDGPU/AMDGPUAttributor.cpp
index 687a7339da379d..6a69b9d2bfc716 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUAttributor.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUAttributor.cpp
@@ -358,7 +358,7 @@ struct AAUniformWorkGroupSizeFunction : public AAUniformWorkGroupSize {
 
       const auto *CallerInfo = A.getAAFor<AAUniformWorkGroupSize>(
           *this, IRPosition::function(*Caller), DepClassTy::REQUIRED);
-      if (!CallerInfo)
+      if (!CallerInfo || !CallerInfo->isValidState())
         return false;
 
       Change = Change | clampStateAndIndicateChange(this->getState(),
@@ -449,7 +449,8 @@ struct AAAMDAttributesFunction : public AAAMDAttributes {
     // Check for Intrinsics and propagate attributes.
     const AACallEdges *AAEdges = A.getAAFor<AACallEdges>(
         *this, this->getIRPosition(), DepClassTy::REQUIRED);
-    if (!AAEdges || AAEdges->hasNonAsmUnknownCallee())
+    if (!AAEdges || !AAEdges->isValidState() ||
+        AAEdges->hasNonAsmUnknownCallee())
       return indicatePessimisticFixpoint();
 
     bool IsNonEntryFunc = !AMDGPU::isEntryFunctionCC(F->getCallingConv());
@@ -465,7 +466,7 @@ struct AAAMDAttributesFunction : public AAAMDAttributes {
       if (IID == Intrinsic::not_intrinsic) {
         const AAAMDAttributes *AAAMD = A.getAAFor<AAAMDAttributes>(
             *this, IRPosition::function(*Callee), DepClassTy::REQUIRED);
-        if (!AAAMD)
+        if (!AAAMD || !AAAMD->isValidState())
           return indicatePessimisticFixpoint();
         *this &= *AAAMD;
         continue;
@@ -660,7 +661,7 @@ struct AAAMDAttributesFunction : public AAAMDAttributes {
 
       const auto *PointerInfoAA = A.getAAFor<AAPointerInfo>(
           *this, IRPosition::callsite_returned(Call), DepClassTy::REQUIRED);
-      if (!PointerInfoAA)
+      if (!PointerInfoAA || !PointerInfoAA->getState().isValidState())
         return false;
 
       return PointerInfoAA->forallInterferingAccesses(
@@ -717,7 +718,7 @@ struct AAAMDSizeRangeAttribute
 
       const auto *CallerInfo = A.getAAFor<AttributeImpl>(
           *this, IRPosition::function(*Caller), DepClassTy::REQUIRED);
-      if (!CallerInfo)
+      if (!CallerInfo || !CallerInfo->isValidState())
         return false;
 
       Change |=
@@ -835,7 +836,8 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
     auto &InfoCache = static_cast<AMDGPUInformationCache &>(A.getInfoCache());
 
     if (const auto *AssumedGroupSize = A.getAAFor<AAAMDFlatWorkGroupSize>(
-            *this, IRPosition::function(*F), DepClassTy::REQUIRED)) {
+            *this, IRPosition::function(*F), DepClassTy::REQUIRED);
+        AssumedGroupSize->isValidState()) {
 
       unsigned Min, Max;
       std::tie(Min, Max) = InfoCache.getWavesPerEU(
@@ -864,7 +866,8 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
           *this, IRPosition::function(*Caller), DepClassTy::REQUIRED);
       const auto *AssumedGroupSize = A.getAAFor<AAAMDFlatWorkGroupSize>(
           *this, IRPosition::function(*Func), DepClassTy::REQUIRED);
-      if (!CallerInfo || !AssumedGroupSize)
+      if (!CallerInfo || !AssumedGroupSize || !CallerInfo->isValidState() ||
+          !AssumedGroupSize->isValidState())
         return false;
 
       unsigned Min, Max;
@@ -982,7 +985,8 @@ struct AAAMDGPUNoAGPR
       // TODO: Handle callsite attributes
       const auto *CalleeInfo = A.getAAFor<AAAMDGPUNoAGPR>(
           *this, IRPosition::function(*Callee), DepClassTy::REQUIRED);
-      return CalleeInfo && CalleeInfo->getAssumed();
+      return CalleeInfo && CalleeInfo->isValidState() &&
+             CalleeInfo->getAssumed();
     };
 
     bool UsedAssumedInformation = false;

``````````

</details>


https://github.com/llvm/llvm-project/pull/114165


More information about the llvm-commits mailing list