[llvm] TargetLibraryInfo: Use pointer index size to determine getSizeTSize(). (PR #118747)

Owen Anderson via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 5 01:26:04 PST 2024


https://github.com/resistor updated https://github.com/llvm/llvm-project/pull/118747

>From 3f0a861e079dc846593cfc45dca0af36eb946749 Mon Sep 17 00:00:00 2001
From: Owen Anderson <resistor at mac.com>
Date: Thu, 5 Dec 2024 19:06:08 +1300
Subject: [PATCH 1/2] TargetLibraryInfo: Use pointer index size to determine
 getSizeTSize().

When using non-integral pointer types, such as on CHERI targets, size_t is equivalent
to the index size, which is allowed to be smaller than the size of the pointer.
---
 llvm/lib/Analysis/TargetLibraryInfo.cpp                  | 9 +++------
 llvm/test/Transforms/InstCombine/stdio-custom-dl.ll      | 3 +--
 .../MergeICmps/X86/distinct-index-width-crash.ll         | 4 ++--
 3 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Analysis/TargetLibraryInfo.cpp b/llvm/lib/Analysis/TargetLibraryInfo.cpp
index e0482b2b1ce025..b4bd53c24eecb0 100644
--- a/llvm/lib/Analysis/TargetLibraryInfo.cpp
+++ b/llvm/lib/Analysis/TargetLibraryInfo.cpp
@@ -1465,13 +1465,10 @@ unsigned TargetLibraryInfoImpl::getSizeTSize(const Module &M) const {
 
   // Historically LLVM assume that size_t has same size as intptr_t (hence
   // deriving the size from sizeof(int*) in address space zero). This should
-  // work for most targets. For future consideration: DataLayout also implement
-  // getIndexSizeInBits which might map better to size_t compared to
-  // getPointerSizeInBits. Hard coding address space zero here might be
-  // unfortunate as well. Maybe getDefaultGlobalsAddressSpace() or
-  // getAllocaAddrSpace() is better.
+  // work for most targets. For future consideration: Hard coding address space
+  // zero here might be unfortunate. Maybe getMaxIndexSizeInBits() is better.
   unsigned AddressSpace = 0;
-  return M.getDataLayout().getPointerSizeInBits(AddressSpace);
+  return M.getDataLayout().getIndexSizeInBits(AddressSpace);
 }
 
 TargetLibraryInfoWrapperPass::TargetLibraryInfoWrapperPass()
diff --git a/llvm/test/Transforms/InstCombine/stdio-custom-dl.ll b/llvm/test/Transforms/InstCombine/stdio-custom-dl.ll
index cc06be7e759d0c..44b702d8391225 100644
--- a/llvm/test/Transforms/InstCombine/stdio-custom-dl.ll
+++ b/llvm/test/Transforms/InstCombine/stdio-custom-dl.ll
@@ -8,11 +8,10 @@ target datalayout = "e-m:o-p:40:64:64:32-i64:64-f80:128-n8:16:32:64-S128"
 @.str.1 = private unnamed_addr constant [2 x i8] c"w\00", align 1
 @.str.2 = private unnamed_addr constant [4 x i8] c"str\00", align 1
 
-; Check fwrite is generated with arguments of ptr size, not index size
 define internal void @fputs_test_custom_dl() {
 ; CHECK-LABEL: @fputs_test_custom_dl(
 ; CHECK-NEXT:    [[CALL:%.*]] = call ptr @fopen(ptr nonnull @.str, ptr nonnull @.str.1)
-; CHECK-NEXT:    [[TMP1:%.*]] = call i40 @fwrite(ptr nonnull @.str.2, i40 3, i40 1, ptr [[CALL]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @fwrite(ptr nonnull @.str.2, i32 3, i32 1, ptr [[CALL]])
 ; CHECK-NEXT:    ret void
 ;
   %call = call ptr @fopen(ptr @.str, ptr @.str.1)
diff --git a/llvm/test/Transforms/MergeICmps/X86/distinct-index-width-crash.ll b/llvm/test/Transforms/MergeICmps/X86/distinct-index-width-crash.ll
index 7dce968ee9de0b..8ff7e95674f963 100644
--- a/llvm/test/Transforms/MergeICmps/X86/distinct-index-width-crash.ll
+++ b/llvm/test/Transforms/MergeICmps/X86/distinct-index-width-crash.ll
@@ -8,7 +8,7 @@ target triple = "x86_64"
 target datalayout = "e-p:64:64:64:32"
 
 ; Define a cunstom data layout that has index width < pointer width
-; and make sure that doesn't mreak anything
+; and make sure that doesn't break anything
 define void @fat_ptrs(ptr dereferenceable(16) %a, ptr dereferenceable(16) %b) {
 ; CHECK-LABEL: @fat_ptrs(
 ; CHECK-NEXT:  bb0:
@@ -16,7 +16,7 @@ define void @fat_ptrs(ptr dereferenceable(16) %a, ptr dereferenceable(16) %b) {
 ; CHECK-NEXT:    [[PTR_B1:%.*]] = getelementptr inbounds [2 x i64], ptr [[B:%.*]], i32 0, i32 1
 ; CHECK-NEXT:    br label %"bb1+bb2"
 ; CHECK:       "bb1+bb2":
-; CHECK-NEXT:    [[MEMCMP:%.*]] = call i32 @memcmp(ptr [[A]], ptr [[B]], i64 16)
+; CHECK-NEXT:    [[MEMCMP:%.*]] = call i32 @memcmp(ptr [[A]], ptr [[B]], i32 16)
 ; CHECK-NEXT:    [[TMP0:%.*]] = icmp eq i32 [[MEMCMP]], 0
 ; CHECK-NEXT:    br label [[BB3:%.*]]
 ; CHECK:       bb3:

>From d2fe2e09ea45139416348718d0f7564f24f8775d Mon Sep 17 00:00:00 2001
From: Owen Anderson <resistor at mac.com>
Date: Thu, 5 Dec 2024 22:20:19 +1300
Subject: [PATCH 2/2] Introduce helpers for materializing size_t values and
 propagate proper usage of them through SimplifyLibCalls.

---
 .../include/llvm/Analysis/TargetLibraryInfo.h | 17 ++++
 llvm/lib/Analysis/TargetLibraryInfo.cpp       |  9 ++
 .../lib/Transforms/Utils/SimplifyLibCalls.cpp | 82 +++++++------------
 .../InstCombine/strcpy-nonzero-as.ll          |  4 +-
 4 files changed, 57 insertions(+), 55 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetLibraryInfo.h b/llvm/include/llvm/Analysis/TargetLibraryInfo.h
index 325c9cd9900b36..15c514b04810a2 100644
--- a/llvm/include/llvm/Analysis/TargetLibraryInfo.h
+++ b/llvm/include/llvm/Analysis/TargetLibraryInfo.h
@@ -20,6 +20,7 @@
 namespace llvm {
 
 template <typename T> class ArrayRef;
+class ConstantInt;
 
 /// Provides info so a possible vectorization of a function can be
 /// computed. Function 'VectorFnName' is equivalent to 'ScalarFnName'
@@ -249,6 +250,12 @@ class TargetLibraryInfoImpl {
   /// Returns the size of the size_t type in bits.
   unsigned getSizeTSize(const Module &M) const;
 
+  /// Returns an IntegerType corresponding to size_t.
+  IntegerType *getSizeTType(const Module &M) const;
+
+  /// Returns a constant materialized as a size_t type.
+  ConstantInt *getAsSizeT(uint64_t V, const Module &M) const;
+
   /// Get size of a C-level int or unsigned int, in bits.
   unsigned getIntSize() const {
     return SizeOfInt;
@@ -565,6 +572,16 @@ class TargetLibraryInfo {
   /// \copydoc TargetLibraryInfoImpl::getSizeTSize()
   unsigned getSizeTSize(const Module &M) const { return Impl->getSizeTSize(M); }
 
+  /// \copydoc TargetLibraryInfoImpl::getSizeTType()
+  IntegerType *getSizeTType(const Module &M) const {
+    return Impl->getSizeTType(M);
+  }
+
+  /// \copydoc TargetLibraryInfoImpl::getAsSizeT()
+  ConstantInt *getAsSizeT(uint64_t V, const Module &M) const {
+    return Impl->getAsSizeT(V, M);
+  }
+
   /// \copydoc TargetLibraryInfoImpl::getIntSize()
   unsigned getIntSize() const {
     return Impl->getIntSize();
diff --git a/llvm/lib/Analysis/TargetLibraryInfo.cpp b/llvm/lib/Analysis/TargetLibraryInfo.cpp
index b4bd53c24eecb0..aedc4b88bf4455 100644
--- a/llvm/lib/Analysis/TargetLibraryInfo.cpp
+++ b/llvm/lib/Analysis/TargetLibraryInfo.cpp
@@ -1471,6 +1471,15 @@ unsigned TargetLibraryInfoImpl::getSizeTSize(const Module &M) const {
   return M.getDataLayout().getIndexSizeInBits(AddressSpace);
 }
 
+IntegerType *TargetLibraryInfoImpl::getSizeTType(const Module &M) const {
+  return IntegerType::get(M.getContext(), getSizeTSize(M));
+}
+
+ConstantInt *TargetLibraryInfoImpl::getAsSizeT(uint64_t V,
+                                               const Module &M) const {
+  return ConstantInt::get(getSizeTType(M), V);
+}
+
 TargetLibraryInfoWrapperPass::TargetLibraryInfoWrapperPass()
     : ImmutablePass(ID), TLA(TargetLibraryInfoImpl()) {
   initializeTargetLibraryInfoWrapperPassPass(*PassRegistry::getPassRegistry());
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index d85e0d99466022..737818b7825cf4 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -397,9 +397,8 @@ Value *LibCallSimplifier::emitStrLenMemCpy(Value *Src, Value *Dst, uint64_t Len,
 
   // We have enough information to now generate the memcpy call to do the
   // concatenation for us.  Make a memcpy to copy the nul byte with align = 1.
-  B.CreateMemCpy(
-      CpyDst, Align(1), Src, Align(1),
-      ConstantInt::get(DL.getIntPtrType(Src->getContext()), Len + 1));
+  B.CreateMemCpy(CpyDst, Align(1), Src, Align(1),
+                 TLI->getAsSizeT(Len + 1, *B.GetInsertBlock()->getModule()));
   return Dst;
 }
 
@@ -590,26 +589,21 @@ Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilderBase &B) {
   if (Len1 && Len2) {
     return copyFlags(
         *CI, emitMemCmp(Str1P, Str2P,
-                        ConstantInt::get(DL.getIntPtrType(CI->getContext()),
-                                         std::min(Len1, Len2)),
+                        TLI->getAsSizeT(std::min(Len1, Len2), *CI->getModule()),
                         B, DL, TLI));
   }
 
   // strcmp to memcmp
   if (!HasStr1 && HasStr2) {
     if (canTransformToMemCmp(CI, Str1P, Len2, DL))
-      return copyFlags(
-          *CI,
-          emitMemCmp(Str1P, Str2P,
-                     ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len2),
-                     B, DL, TLI));
+      return copyFlags(*CI, emitMemCmp(Str1P, Str2P,
+                                       TLI->getAsSizeT(Len2, *CI->getModule()),
+                                       B, DL, TLI));
   } else if (HasStr1 && !HasStr2) {
     if (canTransformToMemCmp(CI, Str2P, Len1, DL))
-      return copyFlags(
-          *CI,
-          emitMemCmp(Str1P, Str2P,
-                     ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len1),
-                     B, DL, TLI));
+      return copyFlags(*CI, emitMemCmp(Str1P, Str2P,
+                                       TLI->getAsSizeT(Len1, *CI->getModule()),
+                                       B, DL, TLI));
   }
 
   annotateNonNullNoUndefBasedOnAccess(CI, {0, 1});
@@ -676,19 +670,15 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilderBase &B) {
   if (!HasStr1 && HasStr2) {
     Len2 = std::min(Len2, Length);
     if (canTransformToMemCmp(CI, Str1P, Len2, DL))
-      return copyFlags(
-          *CI,
-          emitMemCmp(Str1P, Str2P,
-                     ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len2),
-                     B, DL, TLI));
+      return copyFlags(*CI, emitMemCmp(Str1P, Str2P,
+                                       TLI->getAsSizeT(Len2, *CI->getModule()),
+                                       B, DL, TLI));
   } else if (HasStr1 && !HasStr2) {
     Len1 = std::min(Len1, Length);
     if (canTransformToMemCmp(CI, Str2P, Len1, DL))
-      return copyFlags(
-          *CI,
-          emitMemCmp(Str1P, Str2P,
-                     ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len1),
-                     B, DL, TLI));
+      return copyFlags(*CI, emitMemCmp(Str1P, Str2P,
+                                       TLI->getAsSizeT(Len1, *CI->getModule()),
+                                       B, DL, TLI));
   }
 
   return nullptr;
@@ -722,15 +712,13 @@ Value *LibCallSimplifier::optimizeStrCpy(CallInst *CI, IRBuilderBase &B) {
 
   // We have enough information to now generate the memcpy call to do the
   // copy for us.  Make a memcpy to copy the nul byte with align = 1.
-  CallInst *NewCI =
-      B.CreateMemCpy(Dst, Align(1), Src, Align(1),
-                     ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len));
+  CallInst *NewCI = B.CreateMemCpy(Dst, Align(1), Src, Align(1),
+                                   TLI->getAsSizeT(Len, *CI->getModule()));
   mergeAttributesAndFlags(NewCI, *CI);
   return Dst;
 }
 
 Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilderBase &B) {
-  Function *Callee = CI->getCalledFunction();
   Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1);
 
   // stpcpy(d,s) -> strcpy(d,s) if the result is not used.
@@ -749,10 +737,9 @@ Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilderBase &B) {
   else
     return nullptr;
 
-  Type *PT = Callee->getFunctionType()->getParamType(0);
-  Value *LenV = ConstantInt::get(DL.getIntPtrType(PT), Len);
+  Value *LenV = TLI->getAsSizeT(Len, *CI->getModule());
   Value *DstEnd = B.CreateInBoundsGEP(
-      B.getInt8Ty(), Dst, ConstantInt::get(DL.getIntPtrType(PT), Len - 1));
+      B.getInt8Ty(), Dst, TLI->getAsSizeT(Len - 1, *CI->getModule()));
 
   // We have enough information to now generate the memcpy call to do the
   // copy for us.  Make a memcpy to copy the nul byte with align = 1.
@@ -819,13 +806,11 @@ Value *LibCallSimplifier::optimizeStrLCpy(CallInst *CI, IRBuilderBase &B) {
     return ConstantInt::get(CI->getType(), 0);
   }
 
-  Function *Callee = CI->getCalledFunction();
-  Type *PT = Callee->getFunctionType()->getParamType(0);
   // Transform strlcpy(D, S, N) to memcpy(D, S, N') where N' is the lower
   // bound on strlen(S) + 1 and N, optionally followed by a nul store to
   // D[N' - 1] if necessary.
   CallInst *NewCI = B.CreateMemCpy(Dst, Align(1), Src, Align(1),
-                        ConstantInt::get(DL.getIntPtrType(PT), NBytes));
+                                   TLI->getAsSizeT(NBytes, *CI->getModule()));
   mergeAttributesAndFlags(NewCI, *CI);
 
   if (!NulTerm) {
@@ -844,7 +829,6 @@ Value *LibCallSimplifier::optimizeStrLCpy(CallInst *CI, IRBuilderBase &B) {
 // otherwise.
 Value *LibCallSimplifier::optimizeStringNCpy(CallInst *CI, bool RetEnd,
                                              IRBuilderBase &B) {
-  Function *Callee = CI->getCalledFunction();
   Value *Dst = CI->getArgOperand(0);
   Value *Src = CI->getArgOperand(1);
   Value *Size = CI->getArgOperand(2);
@@ -921,11 +905,10 @@ Value *LibCallSimplifier::optimizeStringNCpy(CallInst *CI, bool RetEnd,
                                /*M=*/nullptr, /*AddNull=*/false);
   }
 
-  Type *PT = Callee->getFunctionType()->getParamType(0);
   // st{p,r}ncpy(D, S, N) -> memcpy(align 1 D, align 1 S, N) when both
   // S and N are constant.
   CallInst *NewCI = B.CreateMemCpy(Dst, Align(1), Src, Align(1),
-                                   ConstantInt::get(DL.getIntPtrType(PT), N));
+                                   TLI->getAsSizeT(N, *CI->getModule()));
   mergeAttributesAndFlags(NewCI, *CI);
   if (!RetEnd)
     return Dst;
@@ -3432,10 +3415,9 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI,
       return nullptr; // we found a format specifier, bail out.
 
     // sprintf(str, fmt) -> llvm.memcpy(align 1 str, align 1 fmt, strlen(fmt)+1)
-    B.CreateMemCpy(
-        Dest, Align(1), CI->getArgOperand(1), Align(1),
-        ConstantInt::get(DL.getIntPtrType(CI->getContext()),
-                         FormatStr.size() + 1)); // Copy the null byte.
+    B.CreateMemCpy(Dest, Align(1), CI->getArgOperand(1), Align(1),
+                   // Copy the null byte.
+                   TLI->getAsSizeT(FormatStr.size() + 1, *CI->getModule()));
     return ConstantInt::get(CI->getType(), FormatStr.size());
   }
 
@@ -3470,9 +3452,8 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI,
 
     uint64_t SrcLen = GetStringLength(CI->getArgOperand(2));
     if (SrcLen) {
-      B.CreateMemCpy(
-          Dest, Align(1), CI->getArgOperand(2), Align(1),
-          ConstantInt::get(DL.getIntPtrType(CI->getContext()), SrcLen));
+      B.CreateMemCpy(Dest, Align(1), CI->getArgOperand(2), Align(1),
+                     TLI->getAsSizeT(SrcLen, *CI->getModule()));
       // Returns total number of characters written without null-character.
       return ConstantInt::get(CI->getType(), SrcLen - 1);
     } else if (Value *V = emitStpCpy(Dest, CI->getArgOperand(2), B, TLI)) {
@@ -3570,11 +3551,8 @@ Value *LibCallSimplifier::emitSnPrintfMemCpy(CallInst *CI, Value *StrArg,
   Value *DstArg = CI->getArgOperand(0);
   if (NCopy && StrArg)
     // Transform the call to lvm.memcpy(dst, fmt, N).
-    copyFlags(
-         *CI,
-          B.CreateMemCpy(
-                         DstArg, Align(1), StrArg, Align(1),
-              ConstantInt::get(DL.getIntPtrType(CI->getContext()), NCopy)));
+    copyFlags(*CI, B.CreateMemCpy(DstArg, Align(1), StrArg, Align(1),
+                                  TLI->getAsSizeT(NCopy, *CI->getModule())));
 
   if (N > Str.size())
     // Return early when the whole format string, including the final nul,
@@ -3690,11 +3668,9 @@ Value *LibCallSimplifier::optimizeFPrintFString(CallInst *CI,
     if (FormatStr.contains('%'))
       return nullptr; // We found a format specifier.
 
-    unsigned SizeTBits = TLI->getSizeTSize(*CI->getModule());
-    Type *SizeTTy = IntegerType::get(CI->getContext(), SizeTBits);
     return copyFlags(
         *CI, emitFWrite(CI->getArgOperand(1),
-                        ConstantInt::get(SizeTTy, FormatStr.size()),
+                        TLI->getAsSizeT(FormatStr.size(), *CI->getModule()),
                         CI->getArgOperand(0), B, DL, TLI));
   }
 
diff --git a/llvm/test/Transforms/InstCombine/strcpy-nonzero-as.ll b/llvm/test/Transforms/InstCombine/strcpy-nonzero-as.ll
index 86b49ffdf04b2b..9bde0a3ac3fde0 100644
--- a/llvm/test/Transforms/InstCombine/strcpy-nonzero-as.ll
+++ b/llvm/test/Transforms/InstCombine/strcpy-nonzero-as.ll
@@ -52,7 +52,7 @@ define void @test_strncpy_to_memcpy(ptr addrspace(200) %dst) addrspace(200) noun
 ; CHECK-LABEL: define {{[^@]+}}@test_strncpy_to_memcpy
 ; CHECK-SAME: (ptr addrspace(200) [[DST:%.*]]) addrspace(200) #[[ATTR1]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call addrspace(200) void @llvm.memcpy.p200.p200.i128(ptr addrspace(200) noundef align 1 dereferenceable(17) [[DST]], ptr addrspace(200) noundef align 1 dereferenceable(17) @str, i128 17, i1 false)
+; CHECK-NEXT:    call addrspace(200) void @llvm.memcpy.p200.p200.i64(ptr addrspace(200) noundef align 1 dereferenceable(17) [[DST]], ptr addrspace(200) noundef align 1 dereferenceable(17) @str, i64 17, i1 false)
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -64,7 +64,7 @@ define void @test_stpncpy_to_memcpy(ptr addrspace(200) %dst) addrspace(200) noun
 ; CHECK-LABEL: define {{[^@]+}}@test_stpncpy_to_memcpy
 ; CHECK-SAME: (ptr addrspace(200) [[DST:%.*]]) addrspace(200) #[[ATTR1]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call addrspace(200) void @llvm.memcpy.p200.p200.i128(ptr addrspace(200) noundef align 1 dereferenceable(17) [[DST]], ptr addrspace(200) noundef align 1 dereferenceable(17) @str, i128 17, i1 false)
+; CHECK-NEXT:    call addrspace(200) void @llvm.memcpy.p200.p200.i64(ptr addrspace(200) noundef align 1 dereferenceable(17) [[DST]], ptr addrspace(200) noundef align 1 dereferenceable(17) @str, i64 17, i1 false)
 ; CHECK-NEXT:    ret void
 ;
 entry:



More information about the llvm-commits mailing list