[llvm] [InstCombine][AMDGPU] Disable PtrReplacer when select has mismatch AS. (PR #98456)

via llvm-commits llvm-commits at lists.llvm.org
Sun Jul 14 13:04:57 PDT 2024


https://github.com/jofrn updated https://github.com/llvm/llvm-project/pull/98456

>From 703edf4c5668ebe6826c7c9008ff5b537b8af826 Mon Sep 17 00:00:00 2001
From: jofrn <jofernau at amd.com>
Date: Thu, 11 Jul 2024 05:38:07 -0400
Subject: [PATCH 1/3] [InstCombine][AMDGPU] Disable PtrReplacer when select has
 mismatch AS.

A select has two paths that must have matching addrspaces if InstCombine
is to apply PtrReplacer along them. Keep pointer replacement enabled only if
neither path contains an addrspacecast or if both paths contain a valid
addrspacecast.
---
 .../InstCombineLoadStoreAlloca.cpp            | 43 +++++++++++++++++
 .../AMDGPU/addrspacecast-cmemptrreplacer.ll   | 48 +++++++++++++++++++
 2 files changed, 91 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/AMDGPU/addrspacecast-cmemptrreplacer.ll

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
index 21d5e1dece024..8f39668aef22e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -270,6 +270,8 @@ class PointerReplacer {
     unsigned ToAS = ASC->getDestAddressSpace();
     return (FromAS == ToAS) || IC.isValidAddrSpaceCast(FromAS, ToAS);
   }
+  bool foundASC(const Value *Op, const SelectInst *SI) const;
+  bool hasConflictingAS(const SelectInst *SI) const;
 
   SmallPtrSet<Instruction *, 32> ValuesToRevisit;
   SmallSetVector<Instruction *, 4> Worklist;
@@ -280,6 +282,45 @@ class PointerReplacer {
 };
 } // end anonymous namespace
 
+/// Return true iff Op is an addrspacecast whose src addrspace
+/// is that of the root and whose dst addrspace is that of
+/// the select.
+bool PointerReplacer::foundASC(const Value *Op, const SelectInst *SI) const {
+  const Instruction *I;
+  while ((I = dyn_cast<Instruction>(Op)) != &Root) {
+    if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I)) {
+      unsigned SelectOpSrcAS = ASC->getSrcAddressSpace();
+      unsigned RootAS = Root.getType()->getPointerAddressSpace();
+      unsigned SelectOpDstAS = ASC->getDestAddressSpace();
+      unsigned SelectDstAS = SI->getType()->getPointerAddressSpace();
+      return SelectOpSrcAS == RootAS && SelectOpDstAS == SelectDstAS;
+    }
+    if (I && isa<Instruction>(I->getOperand(0)))
+      Op = I->getOperand(0);
+    else if (I)
+      Op = I->getOperand(1);
+    else
+      break;
+  }
+  return false;
+}
+
+/// Return true iff there is only one ASC from root's addrspace
+/// as an operand to the select.
+bool PointerReplacer::hasConflictingAS(const SelectInst *SI) const {
+  auto *TI = SI->getTrueValue();
+  auto *FI = SI->getFalseValue();
+
+  bool FoundTrueASC = foundASC(TI, SI);
+  bool FoundFalseASC = foundASC(FI, SI);
+
+  bool HasConflictingAS = FoundFalseASC ^ FoundTrueASC;
+  LLVM_DEBUG(dbgs() << "HasConflictingAS: " << HasConflictingAS << "{ False: "
+                    << FoundFalseASC << ", True: " << FoundTrueASC << " }: "
+                    << *SI << '\n');
+  return HasConflictingAS;
+}
+
 bool PointerReplacer::collectUsers() {
   if (!collectUsersRecursive(Root))
     return false;
@@ -323,6 +364,8 @@ bool PointerReplacer::collectUsersRecursive(Instruction &I) {
       if (!isa<Instruction>(SI->getTrueValue()) ||
           !isa<Instruction>(SI->getFalseValue()))
         return false;
+      if (hasConflictingAS(SI))
+        return false;
 
       if (!isAvailable(cast<Instruction>(SI->getTrueValue())) ||
           !isAvailable(cast<Instruction>(SI->getFalseValue()))) {
diff --git a/llvm/test/Transforms/InstCombine/AMDGPU/addrspacecast-cmemptrreplacer.ll b/llvm/test/Transforms/InstCombine/AMDGPU/addrspacecast-cmemptrreplacer.ll
new file mode 100644
index 0000000000000..873b6b10f6465
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/AMDGPU/addrspacecast-cmemptrreplacer.ll
@@ -0,0 +1,48 @@
+; RUN: opt -S -mtriple=amdgcn-amd-amdhsa -passes=instcombine %s | FileCheck %s
+
+; Variant of select with addrspacecast in one branch on path to root alloca (ic opt is then disabled)
+define void @addrspacecast_true_path(ptr addrspace(4) align 8 byref([2 x i8]) %arg) {
+; CHECK-LABEL: define void @addrspacecast_true_path(ptr addrspace(4) byref([2 x i8]) align 8 %arg) {
+; CHECK-NEXT:  %coerce = alloca [2 x i8], align 8, addrspace(5)
+; CHECK-NEXT:  call void @llvm.memcpy.p5.p4.i64(ptr addrspace(5) noundef align 8 dereferenceable(16) %coerce, ptr addrspace(4) noundef align 8 dereferenceable(16) %arg, i64 16, i1 false)
+; CHECK-NEXT:  %load.coerce = load i32, ptr addrspace(5) %coerce, align 8
+; CHECK-NEXT:  %cmp.i = icmp slt i32 %load.coerce, 10
+; CHECK-NEXT:  %inline_values.i = getelementptr inbounds i8, ptr addrspace(5) %coerce, i32 5
+; CHECK-NEXT:  %ret.1 = addrspacecast ptr addrspace(5) %inline_values.i to ptr
+; CHECK-NEXT:  %out_of_line_values.i = getelementptr inbounds i8, ptr addrspace(5) %coerce, i32 6
+; CHECK-NEXT:  %ret.0 = load ptr, ptr addrspace(5) %out_of_line_values.i, align 8
+; CHECK-NEXT:  %retval.0.i = select i1 %cmp.i, ptr %ret.1, ptr %ret.0
+; CHECK-NEXT:  call void @llvm.lifetime.end.p0(i64 8, ptr %retval.0.i)
+; CHECK-NEXT:  ret void
+; CHECK-NEXT:}
+  %coerce = alloca [2 x i8], align 8, addrspace(5)
+  call void @llvm.memcpy.p5.p4.i64(ptr addrspace(5) align 8 %coerce, ptr addrspace(4) align 8 %arg, i64 16, i1 false)
+  %load.coerce = load i32, ptr addrspace(5) %coerce, align 8
+  %cmp.i = icmp slt i32 %load.coerce, 10
+  %inline_values.i = getelementptr inbounds i8, ptr addrspace(5) %coerce, i64 5
+  %ret.1 = addrspacecast ptr addrspace(5) %inline_values.i to ptr
+  %out_of_line_values.i = getelementptr inbounds i8, ptr addrspace(5) %coerce, i64 6
+  %ret.0 = load ptr, ptr addrspace(5) %out_of_line_values.i, align 8
+  %retval.0.i = select i1 %cmp.i, ptr %ret.1, ptr %ret.0
+  call void @llvm.lifetime.end(i64 8, ptr addrspace(0) %retval.0.i)
+  ret void
+}
+
+; Variant of select with valid addrspacecast in both branches on path to root alloca (ic opt remains enabled)
+define void @addrspacecast_both_paths(ptr addrspace(4) align 8 byref([2 x i8]) %arg) {
+; CHECK-LABEL: define void @addrspacecast_both_paths(ptr addrspace(4) byref([2 x i8]) align 8 %arg) {
+; CHECK-NEXT:    ret void
+; CHECK-NEXT:  }
+  %coerce = alloca [2 x i8], align 8, addrspace(5)
+  call void @llvm.memcpy.p5.p4.i64(ptr addrspace(5) align 8 %coerce, ptr addrspace(4) align 8 %arg, i64 16, i1 false)
+  %load.coerce = load i32, ptr addrspace(5) %coerce, align 8
+  %cmp.i = icmp slt i32 %load.coerce, 10
+  %inline_values.i = getelementptr inbounds i8, ptr addrspace(5) %coerce, i64 5
+  %ret.1 = addrspacecast ptr addrspace(5) %inline_values.i to ptr
+  %in_of_line_values.i = getelementptr inbounds i8, ptr addrspace(5) %coerce, i64 6
+  %ret.0 = addrspacecast ptr addrspace(5) %in_of_line_values.i to ptr
+  %retval.0.i = select i1 %cmp.i, ptr %ret.1, ptr %ret.0
+  call void @llvm.lifetime.start(i64 8, ptr addrspace(0) %retval.0.i)
+  ret void
+}
+

>From a2304d14bd0789005a23d6b50f497e911ea6e046 Mon Sep 17 00:00:00 2001
From: jofernau <Joe.Fernau at amd.com>
Date: Fri, 12 Jul 2024 17:27:10 -0400
Subject: [PATCH 2/3] [InstCombine][AMDGPU] changed comments to reflect
 handling any indexing

---
 .../Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp   | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
index 8f39668aef22e..6b3844ddcd145 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -282,7 +282,7 @@ class PointerReplacer {
 };
 } // end anonymous namespace
 
-/// Return true iff Op is an addrspacecast whose src addrspace
+/// Return true iff addrspacecast is found whose src addrspace
 /// is that of the root and whose dst addrspace is that of
 /// the select.
 bool PointerReplacer::foundASC(const Value *Op, const SelectInst *SI) const {
@@ -305,8 +305,8 @@ bool PointerReplacer::foundASC(const Value *Op, const SelectInst *SI) const {
   return false;
 }
 
-/// Return true iff there is only one ASC from root's addrspace
-/// as an operand to the select.
+/// Return true iff valid ASC is found on one
+/// path from the select to the root.
 bool PointerReplacer::hasConflictingAS(const SelectInst *SI) const {
   auto *TI = SI->getTrueValue();
   auto *FI = SI->getFalseValue();

>From f9c9e3cd69d7a6d96e88ba817b09862816c5ca06 Mon Sep 17 00:00:00 2001
From: jofernau <Joe.Fernau at amd.com>
Date: Sun, 14 Jul 2024 16:01:49 -0400
Subject: [PATCH 3/3] [InstCombine] Recurse ASC downward during collection
 instead of scanning upward at end

---
 .../InstCombineLoadStoreAlloca.cpp            | 88 +++++++++----------
 .../AMDGPU/addrspacecast-cmemptrreplacer.ll   | 29 +++++-
 2 files changed, 66 insertions(+), 51 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
index 6b3844ddcd145..e0d540175d030 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -255,7 +255,7 @@ class PointerReplacer {
   void replacePointer(Value *V);
 
 private:
-  bool collectUsersRecursive(Instruction &I);
+  bool collectUsersRecursive(Instruction &I, const AddrSpaceCastInst *ASC = nullptr);
   void replace(Instruction *I);
   Value *getReplacement(Value *I);
   bool isAvailable(Instruction *I) const {
@@ -270,8 +270,7 @@ class PointerReplacer {
     unsigned ToAS = ASC->getDestAddressSpace();
     return (FromAS == ToAS) || IC.isValidAddrSpaceCast(FromAS, ToAS);
   }
-  bool foundASC(const Value *Op, const SelectInst *SI) const;
-  bool hasConflictingAS(const SelectInst *SI) const;
+  bool hasConflictingAS(const SelectInst *SI, const AddrSpaceCastInst &ASC) const;
 
   SmallPtrSet<Instruction *, 32> ValuesToRevisit;
   SmallSetVector<Instruction *, 4> Worklist;
@@ -279,46 +278,15 @@ class PointerReplacer {
   InstCombinerImpl &IC;
   Instruction &Root;
   unsigned FromAS;
+  using SelectAsc = SmallMapVector<SelectInst *, unsigned, 2>;
+  SelectAsc NumValidAscFound;
 };
 } // end anonymous namespace
 
-/// Return true iff addrspacecast is found whose src addrspace
-/// is that of the root and whose dst addrspace is that of
-/// the select.
-bool PointerReplacer::foundASC(const Value *Op, const SelectInst *SI) const {
-  const Instruction *I;
-  while ((I = dyn_cast<Instruction>(Op)) != &Root) {
-    if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I)) {
-      unsigned SelectOpSrcAS = ASC->getSrcAddressSpace();
-      unsigned RootAS = Root.getType()->getPointerAddressSpace();
-      unsigned SelectOpDstAS = ASC->getDestAddressSpace();
-      unsigned SelectDstAS = SI->getType()->getPointerAddressSpace();
-      return SelectOpSrcAS == RootAS && SelectOpDstAS == SelectDstAS;
-    }
-    if (I && isa<Instruction>(I->getOperand(0)))
-      Op = I->getOperand(0);
-    else if (I)
-      Op = I->getOperand(1);
-    else
-      break;
-  }
-  return false;
-}
-
-/// Return true iff valid ASC is found on one
-/// path from the select to the root.
-bool PointerReplacer::hasConflictingAS(const SelectInst *SI) const {
-  auto *TI = SI->getTrueValue();
-  auto *FI = SI->getFalseValue();
-
-  bool FoundTrueASC = foundASC(TI, SI);
-  bool FoundFalseASC = foundASC(FI, SI);
-
-  bool HasConflictingAS = FoundFalseASC ^ FoundTrueASC;
-  LLVM_DEBUG(dbgs() << "HasConflictingAS: " << HasConflictingAS << "{ False: "
-                    << FoundFalseASC << ", True: " << FoundTrueASC << " }: "
-                    << *SI << '\n');
-  return HasConflictingAS;
+bool PointerReplacer::hasConflictingAS(const SelectInst *SI, const AddrSpaceCastInst &ASC) const {
+  unsigned SelectOpDstAS = ASC.getDestAddressSpace();
+  unsigned SelectDstAS = SI->getType()->getPointerAddressSpace();
+  return SelectOpDstAS != SelectDstAS;
 }
 
 bool PointerReplacer::collectUsers() {
@@ -334,9 +302,20 @@ bool PointerReplacer::collectUsers() {
   return true;
 }
 
-bool PointerReplacer::collectUsersRecursive(Instruction &I) {
+bool PointerReplacer::collectUsersRecursive(Instruction &I, const AddrSpaceCastInst *ASC) {
   for (auto *U : I.users()) {
     auto *Inst = cast<Instruction>(&*U);
+
+    if (ASC != nullptr)
+      if (auto *SI = dyn_cast<SelectInst>(Inst))
+        if (!hasConflictingAS(SI, *ASC)) {
+          SelectAsc::iterator it;
+          if ((it = NumValidAscFound.find(SI)) == NumValidAscFound.end())
+            NumValidAscFound.insert(std::make_pair(SI, 1));
+          else
+            ++it->second;
+        }
+
     if (auto *Load = dyn_cast<LoadInst>(Inst)) {
       if (Load->isVolatile())
         return false;
@@ -358,34 +337,49 @@ bool PointerReplacer::collectUsersRecursive(Instruction &I) {
       }
 
       Worklist.insert(PHI);
-      if (!collectUsersRecursive(*PHI))
+      if (!collectUsersRecursive(*PHI, ASC))
         return false;
     } else if (auto *SI = dyn_cast<SelectInst>(Inst)) {
       if (!isa<Instruction>(SI->getTrueValue()) ||
           !isa<Instruction>(SI->getFalseValue()))
         return false;
-      if (hasConflictingAS(SI))
-        return false;
 
       if (!isAvailable(cast<Instruction>(SI->getTrueValue())) ||
           !isAvailable(cast<Instruction>(SI->getFalseValue()))) {
         ValuesToRevisit.insert(Inst);
         continue;
       }
+
+      // If only one path has addrspacecast, then transformation is illegal.
+      SelectAsc::iterator it;
+      if ((it = NumValidAscFound.find(cast<SelectInst>(Inst))) != NumValidAscFound.end()) {
+        assert(it->second <= 2 && "select should have 2 operands");
+        if (it->second == 1)
+          return false;
+      }
+
       Worklist.insert(SI);
-      if (!collectUsersRecursive(*SI))
+      if (!collectUsersRecursive(*SI, ASC))
         return false;
     } else if (isa<GetElementPtrInst>(Inst)) {
       Worklist.insert(Inst);
-      if (!collectUsersRecursive(*Inst))
+      if (!collectUsersRecursive(*Inst, ASC))
         return false;
     } else if (auto *MI = dyn_cast<MemTransferInst>(Inst)) {
       if (MI->isVolatile())
         return false;
       Worklist.insert(Inst);
     } else if (isEqualOrValidAddrSpaceCast(Inst, FromAS)) {
+      auto *NextASC = cast<AddrSpaceCastInst>(Inst);
+      if (ASC == nullptr) {
+        if (Root.getType()->getPointerAddressSpace() != NextASC->getSrcAddressSpace())
+          return false;
+      } else {
+        if (ASC->getDestAddressSpace() != NextASC->getSrcAddressSpace())
+          return false;
+      }
       Worklist.insert(Inst);
-      if (!collectUsersRecursive(*Inst))
+      if (!collectUsersRecursive(*Inst, NextASC))
         return false;
     } else if (Inst->isLifetimeStartOrEnd()) {
       continue;
diff --git a/llvm/test/Transforms/InstCombine/AMDGPU/addrspacecast-cmemptrreplacer.ll b/llvm/test/Transforms/InstCombine/AMDGPU/addrspacecast-cmemptrreplacer.ll
index 873b6b10f6465..0b8ff71a26cc2 100644
--- a/llvm/test/Transforms/InstCombine/AMDGPU/addrspacecast-cmemptrreplacer.ll
+++ b/llvm/test/Transforms/InstCombine/AMDGPU/addrspacecast-cmemptrreplacer.ll
@@ -19,9 +19,9 @@ define void @addrspacecast_true_path(ptr addrspace(4) align 8 byref([2 x i8]) %a
   call void @llvm.memcpy.p5.p4.i64(ptr addrspace(5) align 8 %coerce, ptr addrspace(4) align 8 %arg, i64 16, i1 false)
   %load.coerce = load i32, ptr addrspace(5) %coerce, align 8
   %cmp.i = icmp slt i32 %load.coerce, 10
-  %inline_values.i = getelementptr inbounds i8, ptr addrspace(5) %coerce, i64 5
+  %inline_values.i = getelementptr inbounds i8, ptr addrspace(5) %coerce, i32 5
   %ret.1 = addrspacecast ptr addrspace(5) %inline_values.i to ptr
-  %out_of_line_values.i = getelementptr inbounds i8, ptr addrspace(5) %coerce, i64 6
+  %out_of_line_values.i = getelementptr inbounds i8, ptr addrspace(5) %coerce, i32 6
   %ret.0 = load ptr, ptr addrspace(5) %out_of_line_values.i, align 8
   %retval.0.i = select i1 %cmp.i, ptr %ret.1, ptr %ret.0
   call void @llvm.lifetime.end(i64 8, ptr addrspace(0) %retval.0.i)
@@ -37,12 +37,33 @@ define void @addrspacecast_both_paths(ptr addrspace(4) align 8 byref([2 x i8]) %
   call void @llvm.memcpy.p5.p4.i64(ptr addrspace(5) align 8 %coerce, ptr addrspace(4) align 8 %arg, i64 16, i1 false)
   %load.coerce = load i32, ptr addrspace(5) %coerce, align 8
   %cmp.i = icmp slt i32 %load.coerce, 10
-  %inline_values.i = getelementptr inbounds i8, ptr addrspace(5) %coerce, i64 5
+  %inline_values.i = getelementptr inbounds i8, ptr addrspace(5) %coerce, i32 0
   %ret.1 = addrspacecast ptr addrspace(5) %inline_values.i to ptr
-  %in_of_line_values.i = getelementptr inbounds i8, ptr addrspace(5) %coerce, i64 6
+  %in_of_line_values.i = getelementptr inbounds i8, ptr addrspace(5) %coerce, i32 1
   %ret.0 = addrspacecast ptr addrspace(5) %in_of_line_values.i to ptr
   %retval.0.i = select i1 %cmp.i, ptr %ret.1, ptr %ret.0
   call void @llvm.lifetime.start(i64 8, ptr addrspace(0) %retval.0.i)
   ret void
 }
 
+; Variant of select with multiple valid addrspacecast in both paths
+define void @addrspacecast_multi_asc(ptr addrspace(4) align 8 byref([2 x i8]) %arg) {
+; CHECK-LABEL: define void @addrspacecast_multi_asc(ptr addrspace(4) byref([2 x i8]) align 8 %arg) {
+; CHECK-NEXT:    ret void
+; CHECK-NEXT:  }
+  %coerce = alloca [2 x i8], align 8, addrspace(5)
+  call void @llvm.memcpy.p5.p4.i64(ptr addrspace(5) align 8 %coerce, ptr addrspace(4) align 8 %arg, i64 16, i1 false)
+  %load.coerce = load i32, ptr addrspace(5) %coerce, align 8
+  %cmp.i = icmp slt i32 %load.coerce, 10
+  %inline_values.i = getelementptr inbounds i8, ptr addrspace(5) %coerce, i32 0
+  %tmp.1 = addrspacecast ptr addrspace(5) %inline_values.i to ptr addrspace(3)
+  %ret.1 = addrspacecast ptr addrspace(3) %tmp.1 to ptr
+  %in_of_line_values.i = getelementptr inbounds i8, ptr addrspace(5) %coerce, i32 1
+  %tmp.a.0 = addrspacecast ptr addrspace(5) %in_of_line_values.i to ptr addrspace(70)
+  %tmp.b.0 = addrspacecast ptr addrspace(70) %tmp.a.0 to ptr addrspace(50)
+  %ret.0 = addrspacecast ptr addrspace(50) %tmp.b.0 to ptr
+  %retval.0.i = select i1 %cmp.i, ptr %ret.1, ptr %ret.0
+  call void @llvm.lifetime.start(i64 8, ptr addrspace(0) %retval.0.i)
+  ret void
+}
+



More information about the llvm-commits mailing list