[llvm] TargetLibraryInfo: Use pointer index size to determine getSizeTSize(). (PR #118747)
Owen Anderson via llvm-commits
llvm-commits at lists.llvm.org
Thu Dec 5 01:26:04 PST 2024
https://github.com/resistor updated https://github.com/llvm/llvm-project/pull/118747
>From 3f0a861e079dc846593cfc45dca0af36eb946749 Mon Sep 17 00:00:00 2001
From: Owen Anderson <resistor at mac.com>
Date: Thu, 5 Dec 2024 19:06:08 +1300
Subject: [PATCH 1/2] TargetLibraryInfo: Use pointer index size to determine
getSizeTSize().
When using non-integral pointer types, such as on CHERI targets, size_t is equivalent
to the index size, which is allowed to be smaller than the size of the pointer.
---
llvm/lib/Analysis/TargetLibraryInfo.cpp | 9 +++------
llvm/test/Transforms/InstCombine/stdio-custom-dl.ll | 3 +--
.../MergeICmps/X86/distinct-index-width-crash.ll | 4 ++--
3 files changed, 6 insertions(+), 10 deletions(-)
diff --git a/llvm/lib/Analysis/TargetLibraryInfo.cpp b/llvm/lib/Analysis/TargetLibraryInfo.cpp
index e0482b2b1ce025..b4bd53c24eecb0 100644
--- a/llvm/lib/Analysis/TargetLibraryInfo.cpp
+++ b/llvm/lib/Analysis/TargetLibraryInfo.cpp
@@ -1465,13 +1465,10 @@ unsigned TargetLibraryInfoImpl::getSizeTSize(const Module &M) const {
- // Historically LLVM assume that size_t has same size as intptr_t (hence
- // deriving the size from sizeof(int*) in address space zero). This should
- // work for most targets. For future consideration: DataLayout also implement
- // getIndexSizeInBits which might map better to size_t compared to
- // getPointerSizeInBits. Hard coding address space zero here might be
- // unfortunate as well. Maybe getDefaultGlobalsAddressSpace() or
- // getAllocaAddrSpace() is better.
+ // Historically LLVM assumed that size_t has the same size as intptr_t.
+ // We now use the index size of address space zero instead, since size_t
+ // may be narrower than a pointer (e.g. CHERI). Hard coding address space
+ // zero might be unfortunate; maybe getMaxIndexSizeInBits() is better.
unsigned AddressSpace = 0;
- return M.getDataLayout().getPointerSizeInBits(AddressSpace);
+ return M.getDataLayout().getIndexSizeInBits(AddressSpace);
}
TargetLibraryInfoWrapperPass::TargetLibraryInfoWrapperPass()
diff --git a/llvm/test/Transforms/InstCombine/stdio-custom-dl.ll b/llvm/test/Transforms/InstCombine/stdio-custom-dl.ll
index cc06be7e759d0c..44b702d8391225 100644
--- a/llvm/test/Transforms/InstCombine/stdio-custom-dl.ll
+++ b/llvm/test/Transforms/InstCombine/stdio-custom-dl.ll
@@ -8,11 +8,10 @@ target datalayout = "e-m:o-p:40:64:64:32-i64:64-f80:128-n8:16:32:64-S128"
@.str.1 = private unnamed_addr constant [2 x i8] c"w\00", align 1
@.str.2 = private unnamed_addr constant [4 x i8] c"str\00", align 1
-; Check fwrite is generated with arguments of ptr size, not index size
define internal void @fputs_test_custom_dl() {
; CHECK-LABEL: @fputs_test_custom_dl(
; CHECK-NEXT: [[CALL:%.*]] = call ptr @fopen(ptr nonnull @.str, ptr nonnull @.str.1)
-; CHECK-NEXT: [[TMP1:%.*]] = call i40 @fwrite(ptr nonnull @.str.2, i40 3, i40 1, ptr [[CALL]])
+; CHECK-NEXT: [[TMP1:%.*]] = call i32 @fwrite(ptr nonnull @.str.2, i32 3, i32 1, ptr [[CALL]])
; CHECK-NEXT: ret void
;
%call = call ptr @fopen(ptr @.str, ptr @.str.1)
diff --git a/llvm/test/Transforms/MergeICmps/X86/distinct-index-width-crash.ll b/llvm/test/Transforms/MergeICmps/X86/distinct-index-width-crash.ll
index 7dce968ee9de0b..8ff7e95674f963 100644
--- a/llvm/test/Transforms/MergeICmps/X86/distinct-index-width-crash.ll
+++ b/llvm/test/Transforms/MergeICmps/X86/distinct-index-width-crash.ll
@@ -8,7 +8,7 @@ target triple = "x86_64"
target datalayout = "e-p:64:64:64:32"
-; Define a cunstom data layout that has index width < pointer width
-; and make sure that doesn't mreak anything
+; Define a custom data layout that has index width < pointer width
+; and make sure that doesn't break anything
define void @fat_ptrs(ptr dereferenceable(16) %a, ptr dereferenceable(16) %b) {
; CHECK-LABEL: @fat_ptrs(
; CHECK-NEXT: bb0:
@@ -16,7 +16,7 @@ define void @fat_ptrs(ptr dereferenceable(16) %a, ptr dereferenceable(16) %b) {
; CHECK-NEXT: [[PTR_B1:%.*]] = getelementptr inbounds [2 x i64], ptr [[B:%.*]], i32 0, i32 1
; CHECK-NEXT: br label %"bb1+bb2"
; CHECK: "bb1+bb2":
-; CHECK-NEXT: [[MEMCMP:%.*]] = call i32 @memcmp(ptr [[A]], ptr [[B]], i64 16)
+; CHECK-NEXT: [[MEMCMP:%.*]] = call i32 @memcmp(ptr [[A]], ptr [[B]], i32 16)
; CHECK-NEXT: [[TMP0:%.*]] = icmp eq i32 [[MEMCMP]], 0
; CHECK-NEXT: br label [[BB3:%.*]]
; CHECK: bb3:
>From d2fe2e09ea45139416348718d0f7564f24f8775d Mon Sep 17 00:00:00 2001
From: Owen Anderson <resistor at mac.com>
Date: Thu, 5 Dec 2024 22:20:19 +1300
Subject: [PATCH 2/2] Introduce helpers for materializing size_t values and
propagate proper usage of it through SimplifyLibCalls.
---
.../include/llvm/Analysis/TargetLibraryInfo.h | 17 ++++
llvm/lib/Analysis/TargetLibraryInfo.cpp | 9 ++
.../lib/Transforms/Utils/SimplifyLibCalls.cpp | 82 +++++++------------
.../InstCombine/strcpy-nonzero-as.ll | 4 +-
4 files changed, 57 insertions(+), 55 deletions(-)
diff --git a/llvm/include/llvm/Analysis/TargetLibraryInfo.h b/llvm/include/llvm/Analysis/TargetLibraryInfo.h
index 325c9cd9900b36..15c514b04810a2 100644
--- a/llvm/include/llvm/Analysis/TargetLibraryInfo.h
+++ b/llvm/include/llvm/Analysis/TargetLibraryInfo.h
@@ -20,6 +20,7 @@
namespace llvm {
template <typename T> class ArrayRef;
+class ConstantInt;
/// Provides info so a possible vectorization of a function can be
/// computed. Function 'VectorFnName' is equivalent to 'ScalarFnName'
@@ -249,6 +250,12 @@ class TargetLibraryInfoImpl {
/// Returns the size of the size_t type in bits.
unsigned getSizeTSize(const Module &M) const;
+ /// Returns an IntegerType corresponding to size_t.
+ IntegerType *getSizeTType(const Module &M) const;
+
+ /// Returns a constant materialized as a size_t type.
+ ConstantInt *getAsSizeT(uint64_t V, const Module &M) const;
+
/// Get size of a C-level int or unsigned int, in bits.
unsigned getIntSize() const {
return SizeOfInt;
@@ -565,6 +572,16 @@ class TargetLibraryInfo {
/// \copydoc TargetLibraryInfoImpl::getSizeTSize()
unsigned getSizeTSize(const Module &M) const { return Impl->getSizeTSize(M); }
+ /// \copydoc TargetLibraryInfoImpl::getSizeTType()
+ IntegerType *getSizeTType(const Module &M) const {
+ return Impl->getSizeTType(M);
+ }
+
+ /// \copydoc TargetLibraryInfoImpl::getAsSizeT()
+ ConstantInt *getAsSizeT(uint64_t V, const Module &M) const {
+ return Impl->getAsSizeT(V, M);
+ }
+
/// \copydoc TargetLibraryInfoImpl::getIntSize()
unsigned getIntSize() const {
return Impl->getIntSize();
diff --git a/llvm/lib/Analysis/TargetLibraryInfo.cpp b/llvm/lib/Analysis/TargetLibraryInfo.cpp
index b4bd53c24eecb0..aedc4b88bf4455 100644
--- a/llvm/lib/Analysis/TargetLibraryInfo.cpp
+++ b/llvm/lib/Analysis/TargetLibraryInfo.cpp
@@ -1471,6 +1471,15 @@ unsigned TargetLibraryInfoImpl::getSizeTSize(const Module &M) const {
return M.getDataLayout().getIndexSizeInBits(AddressSpace);
}
+IntegerType *TargetLibraryInfoImpl::getSizeTType(const Module &M) const {
+ return IntegerType::get(M.getContext(), getSizeTSize(M));
+}
+
+ConstantInt *TargetLibraryInfoImpl::getAsSizeT(uint64_t V,
+ const Module &M) const {
+ return ConstantInt::get(getSizeTType(M), V);
+}
+
TargetLibraryInfoWrapperPass::TargetLibraryInfoWrapperPass()
: ImmutablePass(ID), TLA(TargetLibraryInfoImpl()) {
initializeTargetLibraryInfoWrapperPassPass(*PassRegistry::getPassRegistry());
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index d85e0d99466022..737818b7825cf4 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -397,9 +397,8 @@ Value *LibCallSimplifier::emitStrLenMemCpy(Value *Src, Value *Dst, uint64_t Len,
// We have enough information to now generate the memcpy call to do the
// concatenation for us. Make a memcpy to copy the nul byte with align = 1.
- B.CreateMemCpy(
- CpyDst, Align(1), Src, Align(1),
- ConstantInt::get(DL.getIntPtrType(Src->getContext()), Len + 1));
+ B.CreateMemCpy(CpyDst, Align(1), Src, Align(1),
+ TLI->getAsSizeT(Len + 1, *B.GetInsertBlock()->getModule()));
return Dst;
}
@@ -590,26 +589,21 @@ Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilderBase &B) {
if (Len1 && Len2) {
return copyFlags(
*CI, emitMemCmp(Str1P, Str2P,
- ConstantInt::get(DL.getIntPtrType(CI->getContext()),
- std::min(Len1, Len2)),
+ TLI->getAsSizeT(std::min(Len1, Len2), *CI->getModule()),
B, DL, TLI));
}
// strcmp to memcmp
if (!HasStr1 && HasStr2) {
if (canTransformToMemCmp(CI, Str1P, Len2, DL))
- return copyFlags(
- *CI,
- emitMemCmp(Str1P, Str2P,
- ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len2),
- B, DL, TLI));
+ return copyFlags(*CI, emitMemCmp(Str1P, Str2P,
+ TLI->getAsSizeT(Len2, *CI->getModule()),
+ B, DL, TLI));
} else if (HasStr1 && !HasStr2) {
if (canTransformToMemCmp(CI, Str2P, Len1, DL))
- return copyFlags(
- *CI,
- emitMemCmp(Str1P, Str2P,
- ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len1),
- B, DL, TLI));
+ return copyFlags(*CI, emitMemCmp(Str1P, Str2P,
+ TLI->getAsSizeT(Len1, *CI->getModule()),
+ B, DL, TLI));
}
annotateNonNullNoUndefBasedOnAccess(CI, {0, 1});
@@ -676,19 +670,15 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilderBase &B) {
if (!HasStr1 && HasStr2) {
Len2 = std::min(Len2, Length);
if (canTransformToMemCmp(CI, Str1P, Len2, DL))
- return copyFlags(
- *CI,
- emitMemCmp(Str1P, Str2P,
- ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len2),
- B, DL, TLI));
+ return copyFlags(*CI, emitMemCmp(Str1P, Str2P,
+ TLI->getAsSizeT(Len2, *CI->getModule()),
+ B, DL, TLI));
} else if (HasStr1 && !HasStr2) {
Len1 = std::min(Len1, Length);
if (canTransformToMemCmp(CI, Str2P, Len1, DL))
- return copyFlags(
- *CI,
- emitMemCmp(Str1P, Str2P,
- ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len1),
- B, DL, TLI));
+ return copyFlags(*CI, emitMemCmp(Str1P, Str2P,
+ TLI->getAsSizeT(Len1, *CI->getModule()),
+ B, DL, TLI));
}
return nullptr;
@@ -722,15 +712,13 @@ Value *LibCallSimplifier::optimizeStrCpy(CallInst *CI, IRBuilderBase &B) {
// We have enough information to now generate the memcpy call to do the
// copy for us. Make a memcpy to copy the nul byte with align = 1.
- CallInst *NewCI =
- B.CreateMemCpy(Dst, Align(1), Src, Align(1),
- ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len));
+ CallInst *NewCI = B.CreateMemCpy(Dst, Align(1), Src, Align(1),
+ TLI->getAsSizeT(Len, *CI->getModule()));
mergeAttributesAndFlags(NewCI, *CI);
return Dst;
}
Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilderBase &B) {
- Function *Callee = CI->getCalledFunction();
Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1);
// stpcpy(d,s) -> strcpy(d,s) if the result is not used.
@@ -749,10 +737,9 @@ Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilderBase &B) {
else
return nullptr;
- Type *PT = Callee->getFunctionType()->getParamType(0);
- Value *LenV = ConstantInt::get(DL.getIntPtrType(PT), Len);
+ Value *LenV = TLI->getAsSizeT(Len, *CI->getModule());
Value *DstEnd = B.CreateInBoundsGEP(
- B.getInt8Ty(), Dst, ConstantInt::get(DL.getIntPtrType(PT), Len - 1));
+ B.getInt8Ty(), Dst, TLI->getAsSizeT(Len - 1, *CI->getModule()));
// We have enough information to now generate the memcpy call to do the
// copy for us. Make a memcpy to copy the nul byte with align = 1.
@@ -819,13 +806,11 @@ Value *LibCallSimplifier::optimizeStrLCpy(CallInst *CI, IRBuilderBase &B) {
return ConstantInt::get(CI->getType(), 0);
}
- Function *Callee = CI->getCalledFunction();
- Type *PT = Callee->getFunctionType()->getParamType(0);
// Transform strlcpy(D, S, N) to memcpy(D, S, N') where N' is the lower
// bound on strlen(S) + 1 and N, optionally followed by a nul store to
// D[N' - 1] if necessary.
CallInst *NewCI = B.CreateMemCpy(Dst, Align(1), Src, Align(1),
- ConstantInt::get(DL.getIntPtrType(PT), NBytes));
+ TLI->getAsSizeT(NBytes, *CI->getModule()));
mergeAttributesAndFlags(NewCI, *CI);
if (!NulTerm) {
@@ -844,7 +829,6 @@ Value *LibCallSimplifier::optimizeStrLCpy(CallInst *CI, IRBuilderBase &B) {
// otherwise.
Value *LibCallSimplifier::optimizeStringNCpy(CallInst *CI, bool RetEnd,
IRBuilderBase &B) {
- Function *Callee = CI->getCalledFunction();
Value *Dst = CI->getArgOperand(0);
Value *Src = CI->getArgOperand(1);
Value *Size = CI->getArgOperand(2);
@@ -921,11 +905,10 @@ Value *LibCallSimplifier::optimizeStringNCpy(CallInst *CI, bool RetEnd,
/*M=*/nullptr, /*AddNull=*/false);
}
- Type *PT = Callee->getFunctionType()->getParamType(0);
// st{p,r}ncpy(D, S, N) -> memcpy(align 1 D, align 1 S, N) when both
// S and N are constant.
CallInst *NewCI = B.CreateMemCpy(Dst, Align(1), Src, Align(1),
- ConstantInt::get(DL.getIntPtrType(PT), N));
+ TLI->getAsSizeT(N, *CI->getModule()));
mergeAttributesAndFlags(NewCI, *CI);
if (!RetEnd)
return Dst;
@@ -3432,10 +3415,9 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI,
return nullptr; // we found a format specifier, bail out.
// sprintf(str, fmt) -> llvm.memcpy(align 1 str, align 1 fmt, strlen(fmt)+1)
- B.CreateMemCpy(
- Dest, Align(1), CI->getArgOperand(1), Align(1),
- ConstantInt::get(DL.getIntPtrType(CI->getContext()),
- FormatStr.size() + 1)); // Copy the null byte.
+ B.CreateMemCpy(Dest, Align(1), CI->getArgOperand(1), Align(1),
+ // Copy the null byte.
+ TLI->getAsSizeT(FormatStr.size() + 1, *CI->getModule()));
return ConstantInt::get(CI->getType(), FormatStr.size());
}
@@ -3470,9 +3452,8 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI,
uint64_t SrcLen = GetStringLength(CI->getArgOperand(2));
if (SrcLen) {
- B.CreateMemCpy(
- Dest, Align(1), CI->getArgOperand(2), Align(1),
- ConstantInt::get(DL.getIntPtrType(CI->getContext()), SrcLen));
+ B.CreateMemCpy(Dest, Align(1), CI->getArgOperand(2), Align(1),
+ TLI->getAsSizeT(SrcLen, *CI->getModule()));
// Returns total number of characters written without null-character.
return ConstantInt::get(CI->getType(), SrcLen - 1);
} else if (Value *V = emitStpCpy(Dest, CI->getArgOperand(2), B, TLI)) {
@@ -3570,11 +3551,8 @@ Value *LibCallSimplifier::emitSnPrintfMemCpy(CallInst *CI, Value *StrArg,
Value *DstArg = CI->getArgOperand(0);
if (NCopy && StrArg)
- // Transform the call to lvm.memcpy(dst, fmt, N).
+ // Transform the call to llvm.memcpy(dst, fmt, N).
- copyFlags(
- *CI,
- B.CreateMemCpy(
- DstArg, Align(1), StrArg, Align(1),
- ConstantInt::get(DL.getIntPtrType(CI->getContext()), NCopy)));
+ copyFlags(*CI, B.CreateMemCpy(DstArg, Align(1), StrArg, Align(1),
+ TLI->getAsSizeT(NCopy, *CI->getModule())));
if (N > Str.size())
// Return early when the whole format string, including the final nul,
@@ -3690,11 +3668,9 @@ Value *LibCallSimplifier::optimizeFPrintFString(CallInst *CI,
if (FormatStr.contains('%'))
return nullptr; // We found a format specifier.
- unsigned SizeTBits = TLI->getSizeTSize(*CI->getModule());
- Type *SizeTTy = IntegerType::get(CI->getContext(), SizeTBits);
return copyFlags(
*CI, emitFWrite(CI->getArgOperand(1),
- ConstantInt::get(SizeTTy, FormatStr.size()),
+ TLI->getAsSizeT(FormatStr.size(), *CI->getModule()),
CI->getArgOperand(0), B, DL, TLI));
}
diff --git a/llvm/test/Transforms/InstCombine/strcpy-nonzero-as.ll b/llvm/test/Transforms/InstCombine/strcpy-nonzero-as.ll
index 86b49ffdf04b2b..9bde0a3ac3fde0 100644
--- a/llvm/test/Transforms/InstCombine/strcpy-nonzero-as.ll
+++ b/llvm/test/Transforms/InstCombine/strcpy-nonzero-as.ll
@@ -52,7 +52,7 @@ define void @test_strncpy_to_memcpy(ptr addrspace(200) %dst) addrspace(200) noun
; CHECK-LABEL: define {{[^@]+}}@test_strncpy_to_memcpy
; CHECK-SAME: (ptr addrspace(200) [[DST:%.*]]) addrspace(200) #[[ATTR1]] {
; CHECK-NEXT: entry:
-; CHECK-NEXT: call addrspace(200) void @llvm.memcpy.p200.p200.i128(ptr addrspace(200) noundef align 1 dereferenceable(17) [[DST]], ptr addrspace(200) noundef align 1 dereferenceable(17) @str, i128 17, i1 false)
+; CHECK-NEXT: call addrspace(200) void @llvm.memcpy.p200.p200.i64(ptr addrspace(200) noundef align 1 dereferenceable(17) [[DST]], ptr addrspace(200) noundef align 1 dereferenceable(17) @str, i64 17, i1 false)
; CHECK-NEXT: ret void
;
entry:
@@ -64,7 +64,7 @@ define void @test_stpncpy_to_memcpy(ptr addrspace(200) %dst) addrspace(200) noun
; CHECK-LABEL: define {{[^@]+}}@test_stpncpy_to_memcpy
; CHECK-SAME: (ptr addrspace(200) [[DST:%.*]]) addrspace(200) #[[ATTR1]] {
; CHECK-NEXT: entry:
-; CHECK-NEXT: call addrspace(200) void @llvm.memcpy.p200.p200.i128(ptr addrspace(200) noundef align 1 dereferenceable(17) [[DST]], ptr addrspace(200) noundef align 1 dereferenceable(17) @str, i128 17, i1 false)
+; CHECK-NEXT: call addrspace(200) void @llvm.memcpy.p200.p200.i64(ptr addrspace(200) noundef align 1 dereferenceable(17) [[DST]], ptr addrspace(200) noundef align 1 dereferenceable(17) @str, i64 17, i1 false)
; CHECK-NEXT: ret void
;
entry:
More information about the llvm-commits
mailing list