[llvm] [MemCpyOpt] Fix the invalid code modification for GEP (PR #68479)

Kai Yan via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 8 02:37:33 PDT 2023


https://github.com/kaiyan96 updated https://github.com/llvm/llvm-project/pull/68479

>From a6cd812f13ccd266a185f38f03f5767a423104cb Mon Sep 17 00:00:00 2001
From: aklkaiyan <aklkaiyan at tencent.com>
Date: Sat, 7 Oct 2023 10:53:27 +0800
Subject: [PATCH] [MemCpyOpt] Fix the invalid code modification for GEP

Roll back the GEP move performed in performCallSlotOptzn when the call slot
optimization subsequently fails, so that a failed attempt leaves the IR unchanged.
---
 .../lib/Transforms/Scalar/MemCpyOptimizer.cpp | 38 +++++++++++++-
 .../MemCpyOpt/memcpy-gep-position-guard.ll    | 51 +++++++++++++++++++
 2 files changed, 87 insertions(+), 2 deletions(-)
 create mode 100644 llvm/test/Transforms/MemCpyOpt/memcpy-gep-position-guard.ll

diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
index 783ef57805610b9..69401f5b2a5b290 100644
--- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -75,6 +75,34 @@ STATISTIC(NumCallSlot, "Number of call slot optimizations performed");
 STATISTIC(NumStackMove, "Number of stack-move optimizations performed");
 
 namespace {
+/// RAII helper that records an instruction's position and, unless disabled,
+/// moves it back there on destruction, undoing a speculative move when a
+/// transformation is abandoned. The recorded successor must not be erased.
+class InstructionPositionGuard {
+public:
+  InstructionPositionGuard() = default;
+  InstructionPositionGuard(const InstructionPositionGuard &) = delete;
+  InstructionPositionGuard &
+  operator=(const InstructionPositionGuard &) = delete;
+
+  ~InstructionPositionGuard() {
+    if (!DisableGuard && OriginalPos && Inst->getParent())
+      Inst->moveBefore(OriginalPos);
+  }
+  /// Record \p I's position, anchored on its successor (so not a terminator).
+  void setInst(Instruction *I) {
+    assert(!I->isTerminator() && "cannot guard a terminator");
+    Inst = I;
+    OriginalPos = I->getNextNode();
+  }
+  /// Keep the current position; call once the transformation is committed.
+  void disableGuard() { DisableGuard = true; }
+
+private:
+  Instruction *Inst = nullptr;
+  Instruction *OriginalPos = nullptr;
+  bool DisableGuard = false;
+};
 
 /// Represents a range of memset'd bytes with the ByteVal value.
 /// This allows us to analyze stores like:
@@ -1041,15 +1069,20 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad,
     }
   }
 
+  // Restores the GEP's original position if it gets hoisted above C below
+  // and the optimization is later abandoned, so a failed attempt is a no-op.
+  InstructionPositionGuard GEPGuard;
+
   // Since we're changing the parameter to the callsite, we need to make sure
   // that what would be the new parameter dominates the callsite.
   if (!DT->dominates(cpyDest, C)) {
     // Support moving a constant index GEP before the call.
     auto *GEP = dyn_cast<GetElementPtrInst>(cpyDest);
     if (GEP && GEP->hasAllConstantIndices() &&
-        DT->dominates(GEP->getPointerOperand(), C))
+        DT->dominates(GEP->getPointerOperand(), C)) {
+      GEPGuard.setInst(dyn_cast<Instruction>(cpyDest));
       GEP->moveBefore(C);
-    else
+    } else
       return false;
   }
 
@@ -1112,6 +1145,7 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad,
     combineAAMetadata(C, cpyStore);
 
   ++NumCallSlot;
+  GEPGuard.disableGuard();
   return true;
 }
 
diff --git a/llvm/test/Transforms/MemCpyOpt/memcpy-gep-position-guard.ll b/llvm/test/Transforms/MemCpyOpt/memcpy-gep-position-guard.ll
new file mode 100644
index 000000000000000..160e87c94507c9c
--- /dev/null
+++ b/llvm/test/Transforms/MemCpyOpt/memcpy-gep-position-guard.ll
@@ -0,0 +1,51 @@
+; Check that MemCpyOpt restores a hoisted GEP when call slot optimization fails.
+; RUN: opt -S -passes=memcpyopt < %s -verify-memoryssa | FileCheck %s
+
+; ModuleID = 'memcpy_invalid_modify.ll'
+source_filename = "memcpy_invalid_modify.ll"
+
+%struct.MaskedType = type { i8, i8 }
+
+declare void @llvm.lifetime.start.p0(i64, ptr nocapture)
+declare void @llvm.lifetime.end.p0(i64, ptr nocapture)
+declare void @MaskedFunction(ptr, ptr)
+
+define i8 @test_gep_position_guard_enable(ptr noundef %in0, ptr noundef %in1) {
+; CHECK-LABEL: define i8 @test_gep_position_guard_enable(
+; CHECK:    {{%.*}} = alloca [[STRUCT_MASKEDTYPE:%.*]], align 4
+; CHECK:    call void @MaskedFunction(
+; CHECK:    {{%.*}} = load i8
+; CHECK:    {{%.*}} = getelementptr inbounds [[STRUCT_MASKEDTYPE]]
+;
+entry:
+  %funcAlloc = alloca %struct.MaskedType, align 4
+  %ptrAlloc = alloca i8, align 1
+  call void @llvm.lifetime.start.p0(i64 4, ptr %ptrAlloc)
+  %addrspaceCast = addrspacecast ptr %ptrAlloc to ptr addrspace(1)
+  call void @MaskedFunction(ptr noundef %in1, ptr addrspace(1) noundef %addrspaceCast)
+  %load1 = load i8, ptr %ptrAlloc, align 1
+  call void @llvm.lifetime.end.p0(i64 4, ptr %ptrAlloc)
+  %getElemPtr1 = getelementptr inbounds %struct.MaskedType, ptr %funcAlloc, i32 0, i32 1
+  store i8 %load1, ptr %getElemPtr1, align 1
+  ret i8 0
+}
+
+define i8 @test_gep_position_guard_disable(ptr noundef %in0, ptr noundef %in1) {
+; CHECK-LABEL: define i8 @test_gep_position_guard_disable(
+; CHECK:    {{%.*}} = alloca [[STRUCT_MASKEDTYPE:%.*]], align 4
+; CHECK:    {{%.*}} = getelementptr inbounds [[STRUCT_MASKEDTYPE]]
+; CHECK:    call void @MaskedFunction(
+; CHECK-NOT:    {{%.*}} = load i8
+;
+entry:
+  %funcAlloc = alloca %struct.MaskedType, align 4
+  %ptrAlloc = alloca i8, align 1
+  call void @llvm.lifetime.start.p0(i64 4, ptr %ptrAlloc)
+  call void @MaskedFunction(ptr noundef %in1, ptr noundef %ptrAlloc)
+  %load1 = load i8, ptr %ptrAlloc, align 1
+  call void @llvm.lifetime.end.p0(i64 4, ptr %ptrAlloc)
+  %getElemPtr1 = getelementptr inbounds %struct.MaskedType, ptr %funcAlloc, i32 0, i32 1
+  store i8 %load1, ptr %getElemPtr1, align 1
+  ret i8 0
+}
+



More information about the llvm-commits mailing list