[llvm] ed443d8 - [AggressiveInstCombine] Only fold consecutive shifts of loads with constant shift amounts

Arthur Eubanks via llvm-commits llvm-commits at lists.llvm.org
Thu May 4 13:52:36 PDT 2023


Author: Arthur Eubanks
Date: 2023-05-04T13:52:25-07:00
New Revision: ed443d81d16768786b81534094ae64fa0afa5936

URL: https://github.com/llvm/llvm-project/commit/ed443d81d16768786b81534094ae64fa0afa5936
DIFF: https://github.com/llvm/llvm-project/commit/ed443d81d16768786b81534094ae64fa0afa5936.diff

LOG: [AggressiveInstCombine] Only fold consecutive shifts of loads with constant shift amounts

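The folding logic assumed that the shift amounts were constants but never actually checked this; a non-constant shift amount was silently treated as a shift of zero.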

Fixes https://github.com/llvm/llvm-project/issues/62509.

Reviewed By: nikic

Differential Revision: https://reviews.llvm.org/D149896

Added: 
    

Modified: 
    llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
    llvm/test/Transforms/AggressiveInstCombine/X86/or-load.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
index 0b8f853d9df23..3c53c7adb29c4 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -611,7 +611,7 @@ struct LoadOps {
   LoadInst *RootInsert = nullptr;
   bool FoundRoot = false;
   uint64_t LoadSize = 0;
-  Value *Shift = nullptr;
+  const APInt *Shift = nullptr;
   Type *ZextType;
   AAMDNodes AATags;
 };
@@ -621,7 +621,7 @@ struct LoadOps {
 // (ZExt(L1) << shift1) | ZExt(L2) -> ZExt(L3)
 static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
                                AliasAnalysis &AA) {
-  Value *ShAmt2 = nullptr;
+  const APInt *ShAmt2 = nullptr;
   Value *X;
   Instruction *L1, *L2;
 
@@ -629,7 +629,7 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
   if (match(V, m_OneUse(m_c_Or(
                    m_Value(X),
                    m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))),
-                                  m_Value(ShAmt2)))))) ||
+                                  m_APInt(ShAmt2)))))) ||
       match(V, m_OneUse(m_Or(m_Value(X),
                              m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))))))) {
     if (!foldLoadsRecursive(X, LOps, DL, AA) && LOps.FoundRoot)
@@ -640,11 +640,11 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
 
   // Check if the pattern has loads
   LoadInst *LI1 = LOps.Root;
-  Value *ShAmt1 = LOps.Shift;
+  const APInt *ShAmt1 = LOps.Shift;
   if (LOps.FoundRoot == false &&
       (match(X, m_OneUse(m_ZExt(m_Instruction(L1)))) ||
        match(X, m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))),
-                               m_Value(ShAmt1)))))) {
+                               m_APInt(ShAmt1)))))) {
     LI1 = dyn_cast<LoadInst>(L1);
   }
   LoadInst *LI2 = dyn_cast<LoadInst>(L2);
@@ -719,12 +719,11 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
     std::swap(ShAmt1, ShAmt2);
 
   // Find Shifts values.
-  const APInt *Temp;
   uint64_t Shift1 = 0, Shift2 = 0;
-  if (ShAmt1 && match(ShAmt1, m_APInt(Temp)))
-    Shift1 = Temp->getZExtValue();
-  if (ShAmt2 && match(ShAmt2, m_APInt(Temp)))
-    Shift2 = Temp->getZExtValue();
+  if (ShAmt1)
+    Shift1 = ShAmt1->getZExtValue();
+  if (ShAmt2)
+    Shift2 = ShAmt2->getZExtValue();
 
   // First load is always LI1. This is where we put the new load.
   // Use the merged load size available from LI1 for forward loads.
@@ -816,7 +815,7 @@ static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
   // Check if shift needed. We need to shift with the amount of load1
   // shift if not zero.
   if (LOps.Shift)
-    NewOp = Builder.CreateShl(NewOp, LOps.Shift);
+    NewOp = Builder.CreateShl(NewOp, ConstantInt::get(I.getContext(), *LOps.Shift));
   I.replaceAllUsesWith(NewOp);
 
   return true;

diff  --git a/llvm/test/Transforms/AggressiveInstCombine/X86/or-load.ll b/llvm/test/Transforms/AggressiveInstCombine/X86/or-load.ll
index de614173ec6da..842b1f781eac7 100644
--- a/llvm/test/Transforms/AggressiveInstCombine/X86/or-load.ll
+++ b/llvm/test/Transforms/AggressiveInstCombine/X86/or-load.ll
@@ -2253,3 +2253,53 @@ define i32 @loadCombine_4consecutive_badinsert6(ptr %p) {
   %o3 = or i32 %o2, %e1
   ret i32 %o3
 }
+
+define i64 @loadCombine_nonConstShift1(ptr %arg, i8 %b) {
+; ALL-LABEL: @loadCombine_nonConstShift1(
+; ALL-NEXT:    [[G1:%.*]] = getelementptr i8, ptr [[ARG:%.*]], i64 1
+; ALL-NEXT:    [[LD0:%.*]] = load i8, ptr [[ARG]], align 1
+; ALL-NEXT:    [[LD1:%.*]] = load i8, ptr [[G1]], align 1
+; ALL-NEXT:    [[Z0:%.*]] = zext i8 [[LD0]] to i64
+; ALL-NEXT:    [[Z1:%.*]] = zext i8 [[LD1]] to i64
+; ALL-NEXT:    [[Z6:%.*]] = zext i8 [[B:%.*]] to i64
+; ALL-NEXT:    [[S0:%.*]] = shl i64 [[Z0]], [[Z6]]
+; ALL-NEXT:    [[S1:%.*]] = shl i64 [[Z1]], 8
+; ALL-NEXT:    [[O7:%.*]] = or i64 [[S0]], [[S1]]
+; ALL-NEXT:    ret i64 [[O7]]
+;
+  %g1 = getelementptr i8, ptr %arg, i64 1
+  %ld0 = load i8, ptr %arg, align 1
+  %ld1 = load i8, ptr %g1, align 1
+  %z0 = zext i8 %ld0 to i64
+  %z1 = zext i8 %ld1 to i64
+  %z6 = zext i8 %b to i64
+  %s0 = shl i64 %z0, %z6
+  %s1 = shl i64 %z1, 8
+  %o7 = or i64 %s0, %s1
+  ret i64 %o7
+}
+
+define i64 @loadCombine_nonConstShift2(ptr %arg, i8 %b) {
+; ALL-LABEL: @loadCombine_nonConstShift2(
+; ALL-NEXT:    [[G1:%.*]] = getelementptr i8, ptr [[ARG:%.*]], i64 1
+; ALL-NEXT:    [[LD0:%.*]] = load i8, ptr [[ARG]], align 1
+; ALL-NEXT:    [[LD1:%.*]] = load i8, ptr [[G1]], align 1
+; ALL-NEXT:    [[Z0:%.*]] = zext i8 [[LD0]] to i64
+; ALL-NEXT:    [[Z1:%.*]] = zext i8 [[LD1]] to i64
+; ALL-NEXT:    [[Z6:%.*]] = zext i8 [[B:%.*]] to i64
+; ALL-NEXT:    [[S0:%.*]] = shl i64 [[Z0]], [[Z6]]
+; ALL-NEXT:    [[S1:%.*]] = shl i64 [[Z1]], 8
+; ALL-NEXT:    [[O7:%.*]] = or i64 [[S1]], [[S0]]
+; ALL-NEXT:    ret i64 [[O7]]
+;
+  %g1 = getelementptr i8, ptr %arg, i64 1
+  %ld0 = load i8, ptr %arg, align 1
+  %ld1 = load i8, ptr %g1, align 1
+  %z0 = zext i8 %ld0 to i64
+  %z1 = zext i8 %ld1 to i64
+  %z6 = zext i8 %b to i64
+  %s0 = shl i64 %z0, %z6
+  %s1 = shl i64 %z1, 8
+  %o7 = or i64 %s1, %s0
+  ret i64 %o7
+}


        


More information about the llvm-commits mailing list