[llvm] 2ddcf72 - [InstCombine] Perform memset -> load forwarding

Nikita Popov via llvm-commits <llvm-commits@lists.llvm.org>
Thu Nov 3 08:05:14 PDT 2022


Author: Nikita Popov
Date: 2022-11-03T16:03:57+01:00
New Revision: 2ddcf721a0e8dafec5196001b2472480f0011887

URL: https://github.com/llvm/llvm-project/commit/2ddcf721a0e8dafec5196001b2472480f0011887
DIFF: https://github.com/llvm/llvm-project/commit/2ddcf721a0e8dafec5196001b2472480f0011887.diff

LOG: [InstCombine] Perform memset -> load forwarding

InstCombine does some basic store to load forwarding. One case it
currently misses is when the store is actually a memset. This patch
adds support for that case. This is a minimal implementation that
only handles a load at the memset base address, without an offset.
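
For example (see @load_after_memset_0 in the updated tests), a 16-byte
zero memset followed by an i32 load at the base address now folds to a
constant:

  call void @llvm.memset.p0.i64(ptr %a, i8 0, i64 16, i1 false)
  %v = load i32, ptr %a
  ; after this patch, %v is replaced by the constant 0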

GVN is already capable of performing this optimization. Having it
in InstCombine can help with phase ordering issues, similar to the
existing store to load forwarding.

Differential Revision: https://reviews.llvm.org/D137323

Added: 
    

Modified: 
    llvm/lib/Analysis/Loads.cpp
    llvm/test/Transforms/InstCombine/load-store-forward.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp
index 9eff2b161185..93faefa947a3 100644
--- a/llvm/lib/Analysis/Loads.cpp
+++ b/llvm/lib/Analysis/Loads.cpp
@@ -513,6 +513,39 @@ static Value *getAvailableLoadStore(Instruction *Inst, const Value *Ptr,
         return ConstantFoldLoadFromConst(C, AccessTy, DL);
   }
 
+  if (auto *MSI = dyn_cast<MemSetInst>(Inst)) {
+    // Don't forward from (non-atomic) memset to atomic load.
+    if (AtLeastAtomic)
+      return nullptr;
+
+    // Only handle constant memsets.
+    auto *Val = dyn_cast<ConstantInt>(MSI->getValue());
+    auto *Len = dyn_cast<ConstantInt>(MSI->getLength());
+    if (!Val || !Len)
+      return nullptr;
+
+    // TODO: Handle offsets.
+    Value *Dst = MSI->getDest();
+    if (!AreEquivalentAddressValues(Dst, Ptr))
+      return nullptr;
+
+    if (IsLoadCSE)
+      *IsLoadCSE = false;
+
+    // Make sure the read bytes are contained in the memset.
+    TypeSize LoadSize = DL.getTypeSizeInBits(AccessTy);
+    if (LoadSize.isScalable() ||
+        (Len->getValue() * 8).ult(LoadSize.getFixedSize()))
+      return nullptr;
+
+    APInt Splat = APInt::getSplat(LoadSize.getFixedSize(), Val->getValue());
+    ConstantInt *SplatC = ConstantInt::get(MSI->getContext(), Splat);
+    if (CastInst::isBitOrNoopPointerCastable(SplatC->getType(), AccessTy, DL))
+      return SplatC;
+
+    return nullptr;
+  }
+
   return nullptr;
 }
 
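As a quick standalone sketch of the splat computation above (an
illustration only, not part of the commit; it assumes LLVM's APInt API),
forwarding a memset of byte 1 to an i32 load replicates the byte across
the 32-bit load width:

  #include "llvm/ADT/APInt.h"
  #include <cassert>

  int main() {
    // Splat the memset byte 0x01 across the 32-bit load width.
    llvm::APInt Byte(8, 1);
    llvm::APInt Splat = llvm::APInt::getSplat(32, Byte);
    // 0x01010101 == 16843009, the value checked in @load_after_memset_1.
    assert(Splat == 0x01010101u);
    return 0;
  }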

diff --git a/llvm/test/Transforms/InstCombine/load-store-forward.ll b/llvm/test/Transforms/InstCombine/load-store-forward.ll
index d90af935c65e..5a847cd68db8 100644
--- a/llvm/test/Transforms/InstCombine/load-store-forward.ll
+++ b/llvm/test/Transforms/InstCombine/load-store-forward.ll
@@ -257,8 +257,7 @@ define i1 @load_i1_store_i8(ptr %a) {
 define i32 @load_after_memset_0(ptr %a) {
 ; CHECK-LABEL: @load_after_memset_0(
 ; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(16) [[A:%.*]], i8 0, i64 16, i1 false)
-; CHECK-NEXT:    [[V:%.*]] = load i32, ptr [[A]], align 4
-; CHECK-NEXT:    ret i32 [[V]]
+; CHECK-NEXT:    ret i32 0
 ;
   call void @llvm.memset.p0.i64(ptr %a, i8 0, i64 16, i1 false)
   %v = load i32, ptr %a
@@ -268,8 +267,7 @@ define i32 @load_after_memset_0(ptr %a) {
 define float @load_after_memset_0_float(ptr %a) {
 ; CHECK-LABEL: @load_after_memset_0_float(
 ; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(16) [[A:%.*]], i8 0, i64 16, i1 false)
-; CHECK-NEXT:    [[V:%.*]] = load float, ptr [[A]], align 4
-; CHECK-NEXT:    ret float [[V]]
+; CHECK-NEXT:    ret float 0.000000e+00
 ;
   call void @llvm.memset.p0.i64(ptr %a, i8 0, i64 16, i1 false)
   %v = load float, ptr %a
@@ -279,8 +277,7 @@ define float @load_after_memset_0_float(ptr %a) {
 define i27 @load_after_memset_0_non_byte_sized(ptr %a) {
 ; CHECK-LABEL: @load_after_memset_0_non_byte_sized(
 ; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(16) [[A:%.*]], i8 0, i64 16, i1 false)
-; CHECK-NEXT:    [[V:%.*]] = load i27, ptr [[A]], align 4
-; CHECK-NEXT:    ret i27 [[V]]
+; CHECK-NEXT:    ret i27 0
 ;
   call void @llvm.memset.p0.i64(ptr %a, i8 0, i64 16, i1 false)
   %v = load i27, ptr %a
@@ -290,8 +287,7 @@ define i27 @load_after_memset_0_non_byte_sized(ptr %a) {
 define <4 x i8> @load_after_memset_0_vec(ptr %a) {
 ; CHECK-LABEL: @load_after_memset_0_vec(
 ; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(16) [[A:%.*]], i8 0, i64 16, i1 false)
-; CHECK-NEXT:    [[V:%.*]] = load <4 x i8>, ptr [[A]], align 4
-; CHECK-NEXT:    ret <4 x i8> [[V]]
+; CHECK-NEXT:    ret <4 x i8> zeroinitializer
 ;
   call void @llvm.memset.p0.i64(ptr %a, i8 0, i64 16, i1 false)
   %v = load <4 x i8>, ptr %a
@@ -301,8 +297,7 @@ define <4 x i8> @load_after_memset_0_vec(ptr %a) {
 define i32 @load_after_memset_1(ptr %a) {
 ; CHECK-LABEL: @load_after_memset_1(
 ; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(16) [[A:%.*]], i8 1, i64 16, i1 false)
-; CHECK-NEXT:    [[V:%.*]] = load i32, ptr [[A]], align 4
-; CHECK-NEXT:    ret i32 [[V]]
+; CHECK-NEXT:    ret i32 16843009
 ;
   call void @llvm.memset.p0.i64(ptr %a, i8 1, i64 16, i1 false)
   %v = load i32, ptr %a
@@ -312,8 +307,7 @@ define i32 @load_after_memset_1(ptr %a) {
 define float @load_after_memset_1_float(ptr %a) {
 ; CHECK-LABEL: @load_after_memset_1_float(
 ; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(16) [[A:%.*]], i8 1, i64 16, i1 false)
-; CHECK-NEXT:    [[V:%.*]] = load float, ptr [[A]], align 4
-; CHECK-NEXT:    ret float [[V]]
+; CHECK-NEXT:    ret float 0x3820202020000000
 ;
   call void @llvm.memset.p0.i64(ptr %a, i8 1, i64 16, i1 false)
   %v = load float, ptr %a
@@ -323,8 +317,7 @@ define float @load_after_memset_1_float(ptr %a) {
 define i27 @load_after_memset_1_non_byte_sized(ptr %a) {
 ; CHECK-LABEL: @load_after_memset_1_non_byte_sized(
 ; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(16) [[A:%.*]], i8 1, i64 16, i1 false)
-; CHECK-NEXT:    [[V:%.*]] = load i27, ptr [[A]], align 4
-; CHECK-NEXT:    ret i27 [[V]]
+; CHECK-NEXT:    ret i27 16843009
 ;
   call void @llvm.memset.p0.i64(ptr %a, i8 1, i64 16, i1 false)
   %v = load i27, ptr %a
@@ -334,8 +327,7 @@ define i27 @load_after_memset_1_non_byte_sized(ptr %a) {
 define <4 x i8> @load_after_memset_1_vec(ptr %a) {
 ; CHECK-LABEL: @load_after_memset_1_vec(
 ; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(16) [[A:%.*]], i8 1, i64 16, i1 false)
-; CHECK-NEXT:    [[V:%.*]] = load <4 x i8>, ptr [[A]], align 4
-; CHECK-NEXT:    ret <4 x i8> [[V]]
+; CHECK-NEXT:    ret <4 x i8> <i8 1, i8 1, i8 1, i8 1>
 ;
   call void @llvm.memset.p0.i64(ptr %a, i8 1, i64 16, i1 false)
   %v = load <4 x i8>, ptr %a
@@ -353,6 +345,7 @@ define i32 @load_after_memset_unknown(ptr %a, i8 %byte) {
   ret i32 %v
 }
 
+; TODO: Handle load at offset.
 define i32 @load_after_memset_0_offset(ptr %a) {
 ; CHECK-LABEL: @load_after_memset_0_offset(
 ; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(16) [[A:%.*]], i8 0, i64 16, i1 false)
@@ -416,6 +409,17 @@ define i256 @load_after_memset_0_too_small(ptr %a) {
   ret i256 %v
 }
 
+define i129 @load_after_memset_0_too_small_by_one_bit(ptr %a) {
+; CHECK-LABEL: @load_after_memset_0_too_small_by_one_bit(
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(16) [[A:%.*]], i8 0, i64 16, i1 false)
+; CHECK-NEXT:    [[V:%.*]] = load i129, ptr [[A]], align 4
+; CHECK-NEXT:    ret i129 [[V]]
+;
+  call void @llvm.memset.p0.i64(ptr %a, i8 0, i64 16, i1 false)
+  %v = load i129, ptr %a
+  ret i129 %v
+}
+
 define i32 @load_after_memset_0_unknown_length(ptr %a, i64 %len) {
 ; CHECK-LABEL: @load_after_memset_0_unknown_length(
 ; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr align 1 [[A:%.*]], i8 0, i64 [[LEN:%.*]], i1 false)
