[llvm] ed443d8 - [AggressiveInstCombine] Only fold consecutive shifts of loads with constant shift amounts
Arthur Eubanks via llvm-commits
llvm-commits at lists.llvm.org
Thu May 4 13:52:36 PDT 2023
Author: Arthur Eubanks
Date: 2023-05-04T13:52:25-07:00
New Revision: ed443d81d16768786b81534094ae64fa0afa5936
URL: https://github.com/llvm/llvm-project/commit/ed443d81d16768786b81534094ae64fa0afa5936
DIFF: https://github.com/llvm/llvm-project/commit/ed443d81d16768786b81534094ae64fa0afa5936.diff
LOG: [AggressiveInstCombine] Only fold consecutive shifts of loads with constant shift amounts
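The fold already assumed that the shift amounts were constants, but never actually checked it; the shift amounts are now matched with m_APInt so that non-constant shifts are rejected.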
Fixes https://github.com/llvm/llvm-project/issues/62509.
Reviewed By: nikic
Differential Revision: https://reviews.llvm.org/D149896
Added:
Modified:
llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
llvm/test/Transforms/AggressiveInstCombine/X86/or-load.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
index 0b8f853d9df23..3c53c7adb29c4 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -611,7 +611,7 @@ struct LoadOps {
LoadInst *RootInsert = nullptr;
bool FoundRoot = false;
uint64_t LoadSize = 0;
- Value *Shift = nullptr;
+ const APInt *Shift = nullptr;
Type *ZextType;
AAMDNodes AATags;
};
@@ -621,7 +621,7 @@ struct LoadOps {
// (ZExt(L1) << shift1) | ZExt(L2) -> ZExt(L3)
static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
AliasAnalysis &AA) {
- Value *ShAmt2 = nullptr;
+ const APInt *ShAmt2 = nullptr;
Value *X;
Instruction *L1, *L2;
@@ -629,7 +629,7 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
if (match(V, m_OneUse(m_c_Or(
m_Value(X),
m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))),
- m_Value(ShAmt2)))))) ||
+ m_APInt(ShAmt2)))))) ||
match(V, m_OneUse(m_Or(m_Value(X),
m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))))))) {
if (!foldLoadsRecursive(X, LOps, DL, AA) && LOps.FoundRoot)
@@ -640,11 +640,11 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
// Check if the pattern has loads
LoadInst *LI1 = LOps.Root;
- Value *ShAmt1 = LOps.Shift;
+ const APInt *ShAmt1 = LOps.Shift;
if (LOps.FoundRoot == false &&
(match(X, m_OneUse(m_ZExt(m_Instruction(L1)))) ||
match(X, m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))),
- m_Value(ShAmt1)))))) {
+ m_APInt(ShAmt1)))))) {
LI1 = dyn_cast<LoadInst>(L1);
}
LoadInst *LI2 = dyn_cast<LoadInst>(L2);
@@ -719,12 +719,11 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
std::swap(ShAmt1, ShAmt2);
// Find Shifts values.
- const APInt *Temp;
uint64_t Shift1 = 0, Shift2 = 0;
- if (ShAmt1 && match(ShAmt1, m_APInt(Temp)))
- Shift1 = Temp->getZExtValue();
- if (ShAmt2 && match(ShAmt2, m_APInt(Temp)))
- Shift2 = Temp->getZExtValue();
+ if (ShAmt1)
+ Shift1 = ShAmt1->getZExtValue();
+ if (ShAmt2)
+ Shift2 = ShAmt2->getZExtValue();
// First load is always LI1. This is where we put the new load.
// Use the merged load size available from LI1 for forward loads.
@@ -816,7 +815,7 @@ static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
// Check if shift needed. We need to shift with the amount of load1
// shift if not zero.
if (LOps.Shift)
- NewOp = Builder.CreateShl(NewOp, LOps.Shift);
+ NewOp = Builder.CreateShl(NewOp, ConstantInt::get(I.getContext(), *LOps.Shift));
I.replaceAllUsesWith(NewOp);
return true;
diff --git a/llvm/test/Transforms/AggressiveInstCombine/X86/or-load.ll b/llvm/test/Transforms/AggressiveInstCombine/X86/or-load.ll
index de614173ec6da..842b1f781eac7 100644
--- a/llvm/test/Transforms/AggressiveInstCombine/X86/or-load.ll
+++ b/llvm/test/Transforms/AggressiveInstCombine/X86/or-load.ll
@@ -2253,3 +2253,53 @@ define i32 @loadCombine_4consecutive_badinsert6(ptr %p) {
%o3 = or i32 %o2, %e1
ret i32 %o3
}
+
+define i64 @loadCombine_nonConstShift1(ptr %arg, i8 %b) {
+; ALL-LABEL: @loadCombine_nonConstShift1(
+; ALL-NEXT: [[G1:%.*]] = getelementptr i8, ptr [[ARG:%.*]], i64 1
+; ALL-NEXT: [[LD0:%.*]] = load i8, ptr [[ARG]], align 1
+; ALL-NEXT: [[LD1:%.*]] = load i8, ptr [[G1]], align 1
+; ALL-NEXT: [[Z0:%.*]] = zext i8 [[LD0]] to i64
+; ALL-NEXT: [[Z1:%.*]] = zext i8 [[LD1]] to i64
+; ALL-NEXT: [[Z6:%.*]] = zext i8 [[B:%.*]] to i64
+; ALL-NEXT: [[S0:%.*]] = shl i64 [[Z0]], [[Z6]]
+; ALL-NEXT: [[S1:%.*]] = shl i64 [[Z1]], 8
+; ALL-NEXT: [[O7:%.*]] = or i64 [[S0]], [[S1]]
+; ALL-NEXT: ret i64 [[O7]]
+;
+ %g1 = getelementptr i8, ptr %arg, i64 1
+ %ld0 = load i8, ptr %arg, align 1
+ %ld1 = load i8, ptr %g1, align 1
+ %z0 = zext i8 %ld0 to i64
+ %z1 = zext i8 %ld1 to i64
+ %z6 = zext i8 %b to i64
+ %s0 = shl i64 %z0, %z6
+ %s1 = shl i64 %z1, 8
+ %o7 = or i64 %s0, %s1
+ ret i64 %o7
+}
+
+define i64 @loadCombine_nonConstShift2(ptr %arg, i8 %b) {
+; ALL-LABEL: @loadCombine_nonConstShift2(
+; ALL-NEXT: [[G1:%.*]] = getelementptr i8, ptr [[ARG:%.*]], i64 1
+; ALL-NEXT: [[LD0:%.*]] = load i8, ptr [[ARG]], align 1
+; ALL-NEXT: [[LD1:%.*]] = load i8, ptr [[G1]], align 1
+; ALL-NEXT: [[Z0:%.*]] = zext i8 [[LD0]] to i64
+; ALL-NEXT: [[Z1:%.*]] = zext i8 [[LD1]] to i64
+; ALL-NEXT: [[Z6:%.*]] = zext i8 [[B:%.*]] to i64
+; ALL-NEXT: [[S0:%.*]] = shl i64 [[Z0]], [[Z6]]
+; ALL-NEXT: [[S1:%.*]] = shl i64 [[Z1]], 8
+; ALL-NEXT: [[O7:%.*]] = or i64 [[S1]], [[S0]]
+; ALL-NEXT: ret i64 [[O7]]
+;
+ %g1 = getelementptr i8, ptr %arg, i64 1
+ %ld0 = load i8, ptr %arg, align 1
+ %ld1 = load i8, ptr %g1, align 1
+ %z0 = zext i8 %ld0 to i64
+ %z1 = zext i8 %ld1 to i64
+ %z6 = zext i8 %b to i64
+ %s0 = shl i64 %z0, %z6
+ %s1 = shl i64 %z1, 8
+ %o7 = or i64 %s1, %s0
+ ret i64 %o7
+}
More information about the llvm-commits mailing list