[llvm] [MemCpyOpt] Fix the invalid code modification for GEP (PR #68479)

Kai Yan via llvm-commits llvm-commits at lists.llvm.org
Sat Oct 7 04:33:13 PDT 2023


https://github.com/kaiyan96 created https://github.com/llvm/llvm-project/pull/68479

Roll back the GEP move made by performCallSlotOptzn when the optimization
subsequently fails, so that a bailed-out transform leaves the IR unchanged.


>From ff0f90f9accc4ef859feb8841d1a91fa4109afbc Mon Sep 17 00:00:00 2001
From: aklkaiyan <aklkaiyan at tencent.com>
Date: Sat, 7 Oct 2023 10:53:27 +0800
Subject: [PATCH] [MemCpyOpt] Fix the invalid code modification for GEP

Roll back the GEP move made by performCallSlotOptzn when the optimization
subsequently fails, so that a bailed-out transform leaves the IR unchanged.
---
 .../llvm/Transforms/Scalar/MemCpyOptimizer.h  |  1 +
 .../lib/Transforms/Scalar/MemCpyOptimizer.cpp | 35 ++++++++++++++---
 .../MemCpyOpt/memcpy_invalid_modify.ll        | 39 +++++++++++++++++++
 3 files changed, 69 insertions(+), 6 deletions(-)
 create mode 100644 llvm/test/Transforms/MemCpyOpt/memcpy_invalid_modify.ll

diff --git a/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h b/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h
index 6c809bc881d050d..5619c89e5fd8b79 100644
--- a/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h
+++ b/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h
@@ -87,6 +87,7 @@ class MemCpyOptPass : public PassInfoMixin<MemCpyOptPass> {
 
   void eraseInstruction(Instruction *I);
   bool iterateOnFunction(Function &F);
+  void rollbackGEPChange(Value *GepValue, Instruction *GEPPosRecord);
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
index 783ef57805610b9..fb12783629dba49 100644
--- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -1043,13 +1043,17 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad,
 
   // Since we're changing the parameter to the callsite, we need to make sure
   // that what would be the new parameter dominates the callsite.
+  Instruction *GEPPosRecord = nullptr;
   if (!DT->dominates(cpyDest, C)) {
     // Support moving a constant index GEP before the call.
     auto *GEP = dyn_cast<GetElementPtrInst>(cpyDest);
     if (GEP && GEP->hasAllConstantIndices() &&
-        DT->dominates(GEP->getPointerOperand(), C))
+        DT->dominates(GEP->getPointerOperand(), C)) {
+      // We take the node after GEP to record the position, since inst br
+      // is always the last node of BB
+      GEPPosRecord = dyn_cast<Instruction>(cpyDest)->getNextNode();
       GEP->moveBefore(C);
-    else
+    } else
       return false;
   }
 
@@ -1062,19 +1066,25 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad,
   // If necessary, perform additional analysis.
   if (isModOrRefSet(MR))
     MR = BAA.callCapturesBefore(C, DestWithSrcSize, DT);
-  if (isModOrRefSet(MR))
+  if (isModOrRefSet(MR)) {
+    rollbackGEPChange(cpyDest, GEPPosRecord);
     return false;
+  }
 
   // We can't create address space casts here because we don't know if they're
   // safe for the target.
   if (cpySrc->getType()->getPointerAddressSpace() !=
-      cpyDest->getType()->getPointerAddressSpace())
+      cpyDest->getType()->getPointerAddressSpace()) {
+    rollbackGEPChange(cpyDest, GEPPosRecord);
     return false;
+  }
   for (unsigned ArgI = 0; ArgI < C->arg_size(); ++ArgI)
     if (C->getArgOperand(ArgI)->stripPointerCasts() == cpySrc &&
         cpySrc->getType()->getPointerAddressSpace() !=
-            C->getArgOperand(ArgI)->getType()->getPointerAddressSpace())
+            C->getArgOperand(ArgI)->getType()->getPointerAddressSpace()) {
+      rollbackGEPChange(cpyDest, GEPPosRecord);
       return false;
+    }
 
   // All the checks have passed, so do the transformation.
   bool changedArgument = false;
@@ -1092,8 +1102,10 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad,
                                    Dest->getName(), C));
     }
 
-  if (!changedArgument)
+  if (!changedArgument) {
+    rollbackGEPChange(cpyDest, GEPPosRecord);
     return false;
+  }
 
   // If the destination wasn't sufficiently aligned then increase its alignment.
   if (!isDestSufficientlyAligned) {
@@ -1115,6 +1127,17 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad,
   return true;
 }
 
+/// Undo the GEP hoist performed by performCallSlotOptzn when the transform
+/// later bails out, so that a failed optimization leaves the IR unchanged.
+void MemCpyOptPass::rollbackGEPChange(Value *GepValue,
+                                      Instruction *GEPPosRecord) {
+  // Null means the GEP was never hoisted, so there is nothing to undo.
+  if (!GEPPosRecord)
+    return;
+  // A non-null record implies cpyDest is a GEP, so cast<> cannot fail.
+  cast<GetElementPtrInst>(GepValue)->moveBefore(GEPPosRecord);
+}
+
 /// We've found that the (upward scanning) memory dependence of memcpy 'M' is
 /// the memcpy 'MDep'. Try to simplify M to copy from MDep's input if we can.
 bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M,
diff --git a/llvm/test/Transforms/MemCpyOpt/memcpy_invalid_modify.ll b/llvm/test/Transforms/MemCpyOpt/memcpy_invalid_modify.ll
new file mode 100644
index 000000000000000..2858f4227c5b011
--- /dev/null
+++ b/llvm/test/Transforms/MemCpyOpt/memcpy_invalid_modify.ll
@@ -0,0 +1,39 @@
+; RUN: opt -S -passes=memcpyopt -verify-memoryssa -verify-analysis-invalidation < %s | FileCheck %s
+
+; When call slot optimization bails out (here because the call takes the source
+; through an addrspace(1) pointer), the GEP it tentatively hoisted above the
+; call must be put back, leaving the function unchanged.
+
+%struct.MaskedType = type { i32, i32 }
+
+declare void @llvm.lifetime.start.p0(i64, ptr nocapture) #0
+declare void @llvm.lifetime.end.p0(i64, ptr nocapture) #0
+declare void @MaskedFunction(ptr, ptr addrspace(1))
+
+define i32 @test_gep_modified(ptr noundef %0, ptr noundef %1) {
+; CHECK-LABEL: define i32 @test_gep_modified(
+; CHECK:         call void @MaskedFunction(
+; CHECK:         [[GEP:%.*]] = getelementptr inbounds %struct.MaskedType, ptr {{%.*}}, i32 0, i32 1
+; CHECK-NEXT:    store i32 {{%.*}}, ptr [[GEP]], align 4
+
+entry:
+  %2 = alloca %struct.MaskedType, align 4
+  %3 = alloca i32, align 4
+  br label %4
+
+4:                                                ; preds = %4, %entry
+  call void @llvm.lifetime.start.p0(i64 4, ptr %3) #0
+  %5 = addrspacecast ptr %3 to ptr addrspace(1)
+  call void @MaskedFunction(ptr noundef %1, ptr addrspace(1) noundef %5)
+  %6 = load i32, ptr %3, align 4
+  call void @llvm.lifetime.end.p0(i64 4, ptr %3) #0
+  %7 = getelementptr inbounds %struct.MaskedType, ptr %2, i32 0, i32 1
+  store i32 %6, ptr %7, align 4
+  %8 = load i32, ptr %1, align 4
+  %cond = icmp eq i32 %8, 0
+  br i1 %cond, label %4, label %9
+9:                                                ; preds = %4
+  ret i32 0
+}
+
+attributes #0 = { nounwind }
\ No newline at end of file



More information about the llvm-commits mailing list