[llvm] [InstCombine] Do not simplify lshr/shl arg if it is part of fshl rotate pattern. (PR #73441)

via llvm-commits llvm-commits at lists.llvm.org
Sun Nov 26 03:24:18 PST 2023


https://github.com/quic-eikansh updated https://github.com/llvm/llvm-project/pull/73441

>From 6bbc77677d6771bef299d6a9ae0a66a6784b97e9 Mon Sep 17 00:00:00 2001
From: Eikansh Gupta <quic_eikagupt at quicinc.com>
Date: Sat, 25 Nov 2023 06:14:59 -0800
Subject: [PATCH 1/4] [InstCombine] Refactoring matchFunnelShift (NFC)

The matchFunnelShift function both matched the funnel-shift pattern and
created the fshl/fshr instruction. Move the pattern-matching code into a
new function, convertShlOrLShrToFShlOrFShr, so it can be reused by other
optimizations.
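
For reference, a minimal sketch of the kind of pattern the matcher
recognizes (the shift amounts are illustrative, modelled on the existing
fsh.ll tests):

  %shl = shl i32 %x, 5
  %shr = lshr i32 %x, 27
  %r   = or i32 %shl, %shr
  ; matched and rewritten to a funnel shift / rotate:
  ;   %r = call i32 @llvm.fshl.i32(i32 %x, i32 %x, i32 5)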
---
 .../InstCombine/InstCombineAndOrXor.cpp       | 36 ++++++++++++-------
 .../InstCombine/InstCombineInternal.h         |  3 ++
 2 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 02881109f17d29f..fd4b416ec87922f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2706,9 +2706,8 @@ Instruction *InstCombinerImpl::matchBSwapOrBitReverse(Instruction &I,
   return LastInst;
 }
 
-/// Match UB-safe variants of the funnel shift intrinsic.
-static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC,
-                                     const DominatorTree &DT) {
+std::optional<std::tuple<Intrinsic::ID, SmallVector<Value *, 3>>>
+InstCombinerImpl::convertShlOrLShrToFShlOrFShr(Instruction &Or) {
   // TODO: Can we reduce the code duplication between this and the related
   // rotate matching code under visitSelect and visitTrunc?
   unsigned Width = Or.getType()->getScalarSizeInBits();
@@ -2716,7 +2715,7 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC,
   Instruction *Or0, *Or1;
   if (!match(Or.getOperand(0), m_Instruction(Or0)) ||
       !match(Or.getOperand(1), m_Instruction(Or1)))
-    return nullptr;
+    return std::nullopt;
 
   bool IsFshl = true; // Sub on LSHR.
   SmallVector<Value *, 3> FShiftArgs;
@@ -2730,7 +2729,7 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC,
         !match(Or1,
                m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) ||
         Or0->getOpcode() == Or1->getOpcode())
-      return nullptr;
+      return std::nullopt;
 
     // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)).
     if (Or0->getOpcode() == BinaryOperator::LShr) {
@@ -2766,7 +2765,7 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC,
       // might remove it after this fold). This still doesn't guarantee that the
       // final codegen will match this original pattern.
       if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) {
-        KnownBits KnownL = IC.computeKnownBits(L, /*Depth*/ 0, &Or);
+        KnownBits KnownL = computeKnownBits(L, /*Depth*/ 0, &Or);
         return KnownL.getMaxValue().ult(Width) ? L : nullptr;
       }
 
@@ -2810,7 +2809,7 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC,
       IsFshl = false; // Sub on SHL.
     }
     if (!ShAmt)
-      return nullptr;
+      return std::nullopt;
 
     FShiftArgs = {ShVal0, ShVal1, ShAmt};
   } else if (isa<ZExtInst>(Or0) || isa<ZExtInst>(Or1)) {
@@ -2832,18 +2831,18 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC,
     const APInt *ZextHighShlAmt;
     if (!match(Or0,
                m_OneUse(m_Shl(m_Value(ZextHigh), m_APInt(ZextHighShlAmt)))))
-      return nullptr;
+      return std::nullopt;
 
     if (!match(Or1, m_ZExt(m_Value(Low))) ||
         !match(ZextHigh, m_ZExt(m_Value(High))))
-      return nullptr;
+      return std::nullopt;
 
     unsigned HighSize = High->getType()->getScalarSizeInBits();
     unsigned LowSize = Low->getType()->getScalarSizeInBits();
     // Make sure High does not overlap with Low and most significant bits of
     // High aren't shifted out.
     if (ZextHighShlAmt->ult(LowSize) || ZextHighShlAmt->ugt(Width - HighSize))
-      return nullptr;
+      return std::nullopt;
 
     for (User *U : ZextHigh->users()) {
       Value *X, *Y;
@@ -2874,11 +2873,22 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC,
   }
 
   if (FShiftArgs.empty())
-    return nullptr;
+    return std::nullopt;
 
   Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr;
-  Function *F = Intrinsic::getDeclaration(Or.getModule(), IID, Or.getType());
-  return CallInst::Create(F, FShiftArgs);
+  return std::make_tuple(IID, FShiftArgs);
+}
+
+/// Match UB-safe variants of the funnel shift intrinsic.
+static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC,
+                                     const DominatorTree &DT) {
+  if (auto Opt = IC.convertShlOrLShrToFShlOrFShr(Or)) {
+    auto [IID, FShiftArgs] = *Opt;
+    Function *F = Intrinsic::getDeclaration(Or.getModule(), IID, Or.getType());
+    return CallInst::Create(F, FShiftArgs);
+  }
+
+  return nullptr;
 }
 
 /// Attempt to combine or(zext(x),shl(zext(y),bw/2) concat packing patterns.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 0bbb22be71569f6..92cdb9cbda113a5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -236,6 +236,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
     return getLosslessTrunc(C, TruncTy, Instruction::SExt);
   }
 
+  std::optional<std::tuple<Intrinsic::ID, SmallVector<Value *, 3>>>
+  convertShlOrLShrToFShlOrFShr(Instruction &Or); 
+
 private:
   bool annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI);
   bool isDesirableIntType(unsigned BitWidth) const;

>From 1f7dec385fedcdf1f4de172d6fdb8c4f1b20973e Mon Sep 17 00:00:00 2001
From: Eikansh Gupta <quic_eikagupt at quicinc.com>
Date: Sat, 25 Nov 2023 07:58:19 -0800
Subject: [PATCH 2/4] Update formatting in InstCombineInternal.h

---
 llvm/lib/Transforms/InstCombine/InstCombineInternal.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 92cdb9cbda113a5..303d02cc24fc9d3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -237,7 +237,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   }
 
   std::optional<std::tuple<Intrinsic::ID, SmallVector<Value *, 3>>>
-  convertShlOrLShrToFShlOrFShr(Instruction &Or); 
+  convertShlOrLShrToFShlOrFShr(Instruction &Or);
 
 private:
   bool annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI);

>From 70ed494d687ed0e0b4c8d2d8fcec59450e13cfb9 Mon Sep 17 00:00:00 2001
From: Eikansh Gupta <quic_eikagupt at quicinc.com>
Date: Sun, 26 Nov 2023 02:53:32 -0800
Subject: [PATCH 3/4] [InstCombine] Do not simplify lshr/shl arg if it is part
 of fshl rotate pattern.

An fshl/fshr whose first two arguments are the same is lowered to a
target-specific rotate. Depending on its uses, however, one of those
arguments can be simplified, leaving two different arguments that
perform an equivalent operation but no longer match a rotate.

This patch prevents the simplification of the lshr/shl arguments when
they are part of an fshl rotate pattern.
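
A sketch of the problem, modelled on the fsh_rotate_5 test added in this
PR (the values are illustrative):

  %t1  = zext i8 %x to i32
  %or1 = or i32 %t1, %y
  %shr = lshr i32 %or1, 27   ; only bits 27..31 of %or1 are demanded here
  %shl = shl i32 %or1, 5
  %or2 = or i32 %shr, %shl   ; rotate: fshl(%or1, %or1, 5)

Since %t1 occupies only the low 8 bits of %or1, demanded-bits
simplification could rewrite the lshr to shift %y directly, turning the
result into fshl(%or1, %y, 5), which no longer lowers to a rotate. With
this change the lshr/shl operands are left alone when they feed such a
rotate pattern, so the fshl with identical arguments is preserved.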
---
 .../InstCombineSimplifyDemanded.cpp           | 26 +++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index fa076098d63cde5..518fc84a6cca013 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -610,6 +610,19 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
                                                     DemandedMask, Known))
             return R;
 
+      // Do not simplify if shl is part of fshl rotate pattern
+      if (I->hasOneUse()) {
+        auto *Inst = dyn_cast<Instruction>(I->user_back());
+        if (Inst && Inst->getOpcode() == BinaryOperator::Or) {
+          if (auto Opt = convertShlOrLShrToFShlOrFShr(*Inst)) {
+            auto [IID, FShiftArgs] = *Opt;
+            if ((IID == Intrinsic::fshl || IID == Intrinsic::fshr) &&
+                FShiftArgs[0] == FShiftArgs[1])
+              return nullptr;
+          }
+        }
+      }
+
       // TODO: If we only want bits that already match the signbit then we don't
       // need to shift.
 
@@ -670,6 +683,19 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     if (match(I->getOperand(1), m_APInt(SA))) {
       uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
 
+      // Do not simplify if lshr is part of fshl rotate pattern
+      if (I->hasOneUse()) {
+        auto *Inst = dyn_cast<Instruction>(I->user_back());
+        if (Inst && Inst->getOpcode() == BinaryOperator::Or) {
+          if (auto Opt = convertShlOrLShrToFShlOrFShr(*Inst)) {
+            auto [IID, FShiftArgs] = *Opt;
+            if ((IID == Intrinsic::fshl || IID == Intrinsic::fshr) &&
+                FShiftArgs[0] == FShiftArgs[1])
+              return nullptr;
+          }
+        }
+      }
+
       // If we are just demanding the shifted sign bit and below, then this can
       // be treated as an ASHR in disguise.
       if (DemandedMask.countl_zero() >= ShiftAmt) {

>From 524476246ad48547d3ffcd85ccf2d50a85f674c6 Mon Sep 17 00:00:00 2001
From: Eikansh Gupta <quic_eikagupt at quicinc.com>
Date: Sun, 26 Nov 2023 03:23:10 -0800
Subject: [PATCH 4/4] [InstCombine] Add tests

---
 llvm/test/Transforms/InstCombine/fsh.ll | 128 ++++++++++++++++++++++++
 1 file changed, 128 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/fsh.ll b/llvm/test/Transforms/InstCombine/fsh.ll
index 48bf296993f6ac2..9fc19fa5f682816 100644
--- a/llvm/test/Transforms/InstCombine/fsh.ll
+++ b/llvm/test/Transforms/InstCombine/fsh.ll
@@ -722,6 +722,134 @@ define i32 @fsh_orconst_rotate(i32 %a) {
   ret i32 %t2
 }
 
+define i32 @fsh_rotate_5(i8 %x, i32 %y) {
+; CHECK-LABEL: @fsh_rotate_5(
+; CHECK-NEXT:    [[T1:%.*]] = zext i8 [[X:%.*]] to i32
+; CHECK-NEXT:    [[OR1:%.*]] = or i32 [[T1]], [[Y:%.*]]
+; CHECK-NEXT:    [[OR2:%.*]] = call i32 @llvm.fshl.i32(i32 [[OR1]], i32 [[OR1]], i32 5)
+; CHECK-NEXT:    ret i32 [[OR2]]
+;
+
+  %t1 = zext i8 %x to i32
+  %or1 = or i32 %t1, %y
+  %shr = lshr i32 %or1, 27
+  %shl = shl i32 %or1, 5
+  %or2 = or i32 %shr, %shl
+  ret i32 %or2
+}
+
+define i32 @fsh_rotate_18(i8 %x, i32 %y) {
+; CHECK-LABEL: @fsh_rotate_18(
+; CHECK-NEXT:    [[T1:%.*]] = zext i8 [[X:%.*]] to i32
+; CHECK-NEXT:    [[OR1:%.*]] = or i32 [[T1]], [[Y:%.*]]
+; CHECK-NEXT:    [[OR2:%.*]] = call i32 @llvm.fshl.i32(i32 [[OR1]], i32 [[OR1]], i32 18)
+; CHECK-NEXT:    ret i32 [[OR2]]
+;
+
+  %t1 = zext i8 %x to i32
+  %or1 = or i32 %t1, %y
+  %shr = lshr i32 %or1, 14
+  %shl = shl i32 %or1, 18
+  %or2 = or i32 %shr, %shl
+  ret i32 %or2
+}
+
+define i32 @fsh_load_rotate_12(ptr %data) {
+; CHECK-LABEL: @fsh_load_rotate_12(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = load i8, ptr [[DATA:%.*]], align 1
+; CHECK-NEXT:    [[CONV:%.*]] = zext i8 [[TMP0]] to i32
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 [[CONV]], 24
+; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds i8, ptr [[DATA]], i64 1
+; CHECK-NEXT:    [[TMP1:%.*]] = load i8, ptr [[ARRAYIDX1]], align 1
+; CHECK-NEXT:    [[CONV2:%.*]] = zext i8 [[TMP1]] to i32
+; CHECK-NEXT:    [[SHL3:%.*]] = shl nuw nsw i32 [[CONV2]], 16
+; CHECK-NEXT:    [[OR:%.*]] = or i32 [[SHL3]], [[SHL]]
+; CHECK-NEXT:    [[ARRAYIDX4:%.*]] = getelementptr inbounds i8, ptr [[DATA]], i64 2
+; CHECK-NEXT:    [[TMP2:%.*]] = load i8, ptr [[ARRAYIDX4]], align 1
+; CHECK-NEXT:    [[CONV5:%.*]] = zext i8 [[TMP2]] to i32
+; CHECK-NEXT:    [[SHL6:%.*]] = shl nuw nsw i32 [[CONV5]], 8
+; CHECK-NEXT:    [[OR7:%.*]] = or i32 [[OR]], [[SHL6]]
+; CHECK-NEXT:    [[ARRAYIDX8:%.*]] = getelementptr inbounds i8, ptr [[DATA]], i64 3
+; CHECK-NEXT:    [[TMP3:%.*]] = load i8, ptr [[ARRAYIDX8]], align 1
+; CHECK-NEXT:    [[CONV9:%.*]] = zext i8 [[TMP3]] to i32
+; CHECK-NEXT:    [[OR10:%.*]] = or i32 [[OR7]], [[CONV9]]
+; CHECK-NEXT:    [[OR15:%.*]] = call i32 @llvm.fshl.i32(i32 [[OR10]], i32 [[OR10]], i32 12)
+; CHECK-NEXT:    ret i32 [[OR15]]
+;
+
+entry:
+  %0 = load i8, ptr %data
+  %conv = zext i8 %0 to i32
+  %shl = shl nuw i32 %conv, 24
+  %arrayidx1 = getelementptr inbounds i8, ptr %data, i64 1
+  %1 = load i8, ptr %arrayidx1
+  %conv2 = zext i8 %1 to i32
+  %shl3 = shl nuw nsw i32 %conv2, 16
+  %or = or i32 %shl3, %shl
+  %arrayidx4 = getelementptr inbounds i8, ptr %data, i64 2
+  %2 = load i8, ptr %arrayidx4
+  %conv5 = zext i8 %2 to i32
+  %shl6 = shl nuw nsw i32 %conv5, 8
+  %or7 = or i32 %or, %shl6
+  %arrayidx8 = getelementptr inbounds i8, ptr %data, i64 3
+  %3 = load i8, ptr %arrayidx8
+  %conv9 = zext i8 %3 to i32
+  %or10 = or i32 %or7, %conv9
+  %shr = lshr i32 %or10, 20
+  %shl7 = shl i32 %or10, 12
+  %or15 = or i32 %shr, %shl7
+  ret i32 %or15
+}
+
+define i32 @fsh_load_rotate_25(ptr %data) {
+; CHECK-LABEL: @fsh_load_rotate_25(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = load i8, ptr [[DATA:%.*]], align 1
+; CHECK-NEXT:    [[CONV:%.*]] = zext i8 [[TMP0]] to i32
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 [[CONV]], 24
+; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds i8, ptr [[DATA]], i64 1
+; CHECK-NEXT:    [[TMP1:%.*]] = load i8, ptr [[ARRAYIDX1]], align 1
+; CHECK-NEXT:    [[CONV2:%.*]] = zext i8 [[TMP1]] to i32
+; CHECK-NEXT:    [[SHL3:%.*]] = shl nuw nsw i32 [[CONV2]], 16
+; CHECK-NEXT:    [[OR:%.*]] = or i32 [[SHL3]], [[SHL]]
+; CHECK-NEXT:    [[ARRAYIDX4:%.*]] = getelementptr inbounds i8, ptr [[DATA]], i64 2
+; CHECK-NEXT:    [[TMP2:%.*]] = load i8, ptr [[ARRAYIDX4]], align 1
+; CHECK-NEXT:    [[CONV5:%.*]] = zext i8 [[TMP2]] to i32
+; CHECK-NEXT:    [[SHL6:%.*]] = shl nuw nsw i32 [[CONV5]], 8
+; CHECK-NEXT:    [[OR7:%.*]] = or i32 [[OR]], [[SHL6]]
+; CHECK-NEXT:    [[ARRAYIDX8:%.*]] = getelementptr inbounds i8, ptr [[DATA]], i64 3
+; CHECK-NEXT:    [[TMP3:%.*]] = load i8, ptr [[ARRAYIDX8]], align 1
+; CHECK-NEXT:    [[CONV9:%.*]] = zext i8 [[TMP3]] to i32
+; CHECK-NEXT:    [[OR10:%.*]] = or i32 [[OR7]], [[CONV9]]
+; CHECK-NEXT:    [[OR15:%.*]] = call i32 @llvm.fshl.i32(i32 [[OR10]], i32 [[OR10]], i32 25)
+; CHECK-NEXT:    ret i32 [[OR15]]
+;
+
+entry:
+  %0 = load i8, ptr %data
+  %conv = zext i8 %0 to i32
+  %shl = shl nuw i32 %conv, 24
+  %arrayidx1 = getelementptr inbounds i8, ptr %data, i64 1
+  %1 = load i8, ptr %arrayidx1
+  %conv2 = zext i8 %1 to i32
+  %shl3 = shl nuw nsw i32 %conv2, 16
+  %or = or i32 %shl3, %shl
+  %arrayidx4 = getelementptr inbounds i8, ptr %data, i64 2
+  %2 = load i8, ptr %arrayidx4
+  %conv5 = zext i8 %2 to i32
+  %shl6 = shl nuw nsw i32 %conv5, 8
+  %or7 = or i32 %or, %shl6
+  %arrayidx8 = getelementptr inbounds i8, ptr %data, i64 3
+  %3 = load i8, ptr %arrayidx8
+  %conv9 = zext i8 %3 to i32
+  %or10 = or i32 %or7, %conv9
+  %shr = lshr i32 %or10, 7
+  %shl7 = shl i32 %or10, 25
+  %or15 = or i32 %shr, %shl7
+  ret i32 %or15
+}
+
 define <2 x i31> @fshr_mask_args_same_vector(<2 x i31> %a) {
 ; CHECK-LABEL: @fshr_mask_args_same_vector(
 ; CHECK-NEXT:    [[T3:%.*]] = shl <2 x i31> [[A:%.*]], <i31 10, i31 10>


