[llvm] [InferAddressSpaces] Fix constant replace to avoid modifying other functions (PR #70611)

Wenju He via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 29 17:43:56 PDT 2023


https://github.com/wenju-he created https://github.com/llvm/llvm-project/pull/70611

A constant value is uniqued in the LLVMContext, so its uses may appear in
any function. InferAddressSpaces was replacing those uses in other functions
as well, which led to unexpected behavior in our downstream use case after
the pass.

InferAddressSpaces is a function pass, so it must not modify any function
other than the one currently being processed.

Co-authored-by: Abhinav Gaba <abhinav.gaba at intel.com>

>From 7c41be75c1ef661e757bfaca8d693b3937df649e Mon Sep 17 00:00:00 2001
From: Wenju He <wenju.he at intel.com>
Date: Mon, 30 Oct 2023 08:36:52 +0800
Subject: [PATCH] [InferAddressSpaces] Fix constant replace to avoid modifying
 other functions

A constant value is uniqued in the LLVMContext, so its uses may appear in
any function. InferAddressSpaces was replacing those uses in other functions
as well, which led to unexpected behavior in our downstream use case after
the pass.

InferAddressSpaces is a function pass, so it must not modify any function
other than the one currently being processed.

Co-authored-by: Abhinav Gaba <abhinav.gaba at intel.com>
---
 llvm/include/llvm/IR/User.h                   | 10 +++++
 .../Transforms/Scalar/InferAddressSpaces.cpp  | 20 +++++++++-
 .../ensure-other-funcs-unchanged.ll           | 40 +++++++++++++++++++
 3 files changed, 69 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/Transforms/InferAddressSpaces/ensure-other-funcs-unchanged.ll

diff --git a/llvm/include/llvm/IR/User.h b/llvm/include/llvm/IR/User.h
index a9cf60151e5dc6c..d27d4bf4f5f1e66 100644
--- a/llvm/include/llvm/IR/User.h
+++ b/llvm/include/llvm/IR/User.h
@@ -18,6 +18,7 @@
 #ifndef LLVM_IR_USER_H
 #define LLVM_IR_USER_H
 
+#include "llvm/ADT/GraphTraits.h"
 #include "llvm/ADT/iterator.h"
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/IR/Use.h"
@@ -334,6 +335,15 @@ template<> struct simplify_type<User::const_op_iterator> {
   }
 };
 
+template <> struct GraphTraits<User *> {
+  using NodeRef = User *;
+  using ChildIteratorType = Value::user_iterator;
+  // A node's children are its users, so graph walks follow def -> use edges.
+  static NodeRef getEntryNode(NodeRef N) { return N; }
+  static ChildIteratorType child_begin(NodeRef N) { return N->user_begin(); }
+  static ChildIteratorType child_end(NodeRef N) { return N->user_end(); }
+};
+
 } // end namespace llvm
 
 #endif // LLVM_IR_USER_H
diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index 2da521375c00161..828b1e765cb1af2 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -1166,6 +1166,8 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces(
   }
 
   SmallVector<Instruction *, 16> DeadInstructions;
+  ValueToValueMapTy VMap;
+  ValueMapper VMapper(VMap, RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
 
   // Replaces the uses of the old address expressions with the new ones.
   for (const WeakTrackingVH &WVH : Postorder) {
@@ -1184,7 +1186,18 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces(
       if (C != Replace) {
         LLVM_DEBUG(dbgs() << "Inserting replacement const cast: " << Replace
                           << ": " << *Replace << '\n');
-        C->replaceAllUsesWith(Replace);
+        VMap[C] = Replace;
+        for (User *U : make_early_inc_range(C->users())) {
+          for (auto It = df_begin(U), E = df_end(U); It != E;) {
+            if (auto *I = dyn_cast<Instruction>(*It)) {
+              if (I->getFunction() == F)
+                VMapper.remapInstruction(*I);
+              It.skipChildren();
+              continue;
+            }
+            ++It;
+          }
+        }
         V = Replace;
       }
     }
@@ -1210,6 +1223,11 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces(
       // Skip if the current user is the new value itself.
       if (CurUser == NewV)
         continue;
+
+      if (auto *CurUserI = dyn_cast<Instruction>(CurUser);
+          CurUserI && CurUserI->getFunction() != F)
+        continue;
+
       // Handle more complex cases like intrinsic that need to be remangled.
       if (auto *MI = dyn_cast<MemIntrinsic>(CurUser)) {
         if (!MI->isVolatile() && handleMemIntrinsicPtrUse(MI, V, NewV))
diff --git a/llvm/test/Transforms/InferAddressSpaces/ensure-other-funcs-unchanged.ll b/llvm/test/Transforms/InferAddressSpaces/ensure-other-funcs-unchanged.ll
new file mode 100644
index 000000000000000..ae052a69f9ed02d
--- /dev/null
+++ b/llvm/test/Transforms/InferAddressSpaces/ensure-other-funcs-unchanged.ll
@@ -0,0 +1,40 @@
+; RUN: opt -assume-default-is-flat-addrspace -print-module-scope -print-after-all -S -disable-output -passes=infer-address-spaces <%s 2>&1 | FileCheck %s
+
+; CHECK: IR Dump After InferAddressSpacesPass on f2
+
+; Check that after running infer-address-spaces on f2, the redundant addrspace cast %x1 in f2 is gone.
+; CHECK-LABEL: define spir_func void @f2()
+; CHECK:         [[X:%.*]] = addrspacecast ptr addrspace(1) @x to ptr
+; CHECK-NEXT:    call spir_func void @f1(ptr noundef [[X]])
+
+; But it should not affect f3.
+; CHECK-LABEL: define spir_func void @f3()
+; CHECK:         %x1 = addrspacecast ptr addrspacecast (ptr addrspace(1) @x to ptr) to ptr addrspace(1)
+; CHECK-NEXT:    %x2 = addrspacecast ptr addrspace(1) %x1 to ptr
+; CHECK-NEXT:    call spir_func void @f1(ptr noundef %x2)
+
+; Ensure that the pass hasn't run on f3 yet.
+; CHECK: IR Dump After InferAddressSpacesPass on f3
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64"
+
+@x = addrspace(1) global i32 0, align 4 ; Its addrspacecast ConstantExpr is shared by @f2 and @f3.
+
+define spir_func void @f2() { ; Expected: %x1/%x2 fold into one addrspacecast (see CHECKs).
+entry:
+  %x1 = addrspacecast ptr addrspacecast (ptr addrspace(1) @x to ptr) to ptr addrspace(1)
+  %x2 = addrspacecast ptr addrspace(1) %x1 to ptr
+  call spir_func void @f1(ptr noundef %x2)
+  ret void
+}
+
+define spir_func void @f3() { ; Same body as @f2; must stay unchanged while @f2 is processed.
+entry:
+  %x1 = addrspacecast ptr addrspacecast (ptr addrspace(1) @x to ptr) to ptr addrspace(1)
+  %x2 = addrspacecast ptr addrspace(1) %x1 to ptr
+  call spir_func void @f1(ptr noundef %x2)
+  ret void
+}
+
+declare spir_func void @f1(ptr noundef)



More information about the llvm-commits mailing list