[llvm] [InstCombine] Convert or concat to fshl if opposite or concat exists (PR #68502)

via llvm-commits llvm-commits at lists.llvm.org
Sat Oct 7 20:16:46 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

<details>
<summary>Changes</summary>

If two 'or' instructions concatenate the same variables in opposite order
and the first 'or' dominates the second one, the second 'or' can be
rewritten as an fshl that rotates the result of the first 'or'. This
eliminates a shl and exposes more optimization opportunities for
bswap/bitreverse.

---
Full diff: https://github.com/llvm/llvm-project/pull/68502.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+134-78) 
- (modified) llvm/test/Transforms/InstCombine/funnel.ll (+42) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index cbdab3e9c5fb91d..09017e46ee3bd4d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2727,105 +2727,161 @@ Instruction *InstCombinerImpl::matchBSwapOrBitReverse(Instruction &I,
 }
 
 /// Match UB-safe variants of the funnel shift intrinsic.
-static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC) {
+static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC,
+                                     const DominatorTree &DT) {
   // TODO: Can we reduce the code duplication between this and the related
   // rotate matching code under visitSelect and visitTrunc?
   unsigned Width = Or.getType()->getScalarSizeInBits();
 
-  // First, find an or'd pair of opposite shifts:
-  // or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1)
-  BinaryOperator *Or0, *Or1;
-  if (!match(Or.getOperand(0), m_BinOp(Or0)) ||
-      !match(Or.getOperand(1), m_BinOp(Or1)))
+  Instruction *Or0, *Or1;
+  if (!match(Or.getOperand(0), m_Instruction(Or0)) ||
+      !match(Or.getOperand(1), m_Instruction(Or1)))
     return nullptr;
 
-  Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1;
-  if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) ||
-      !match(Or1, m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) ||
-      Or0->getOpcode() == Or1->getOpcode())
-    return nullptr;
+  bool IsFshl = true; // Sub on LSHR.
+  SmallVector<Value *, 3> FShiftArgs;
 
-  // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)).
-  if (Or0->getOpcode() == BinaryOperator::LShr) {
-    std::swap(Or0, Or1);
-    std::swap(ShVal0, ShVal1);
-    std::swap(ShAmt0, ShAmt1);
-  }
-  assert(Or0->getOpcode() == BinaryOperator::Shl &&
-         Or1->getOpcode() == BinaryOperator::LShr &&
-         "Illegal or(shift,shift) pair");
-
-  // Match the shift amount operands for a funnel shift pattern. This always
-  // matches a subtraction on the R operand.
-  auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * {
-    // Check for constant shift amounts that sum to the bitwidth.
-    const APInt *LI, *RI;
-    if (match(L, m_APIntAllowUndef(LI)) && match(R, m_APIntAllowUndef(RI)))
-      if (LI->ult(Width) && RI->ult(Width) && (*LI + *RI) == Width)
-        return ConstantInt::get(L->getType(), *LI);
-
-    Constant *LC, *RC;
-    if (match(L, m_Constant(LC)) && match(R, m_Constant(RC)) &&
-        match(L, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
-        match(R, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
-        match(ConstantExpr::getAdd(LC, RC), m_SpecificIntAllowUndef(Width)))
-      return ConstantExpr::mergeUndefsWith(LC, RC);
-
-    // (shl ShVal, X) | (lshr ShVal, (Width - x)) iff X < Width.
-    // We limit this to X < Width in case the backend re-expands the intrinsic,
-    // and has to reintroduce a shift modulo operation (InstCombine might remove
-    // it after this fold). This still doesn't guarantee that the final codegen
-    // will match this original pattern.
-    if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) {
-      KnownBits KnownL = IC.computeKnownBits(L, /*Depth*/ 0, &Or);
-      return KnownL.getMaxValue().ult(Width) ? L : nullptr;
+  // First, find an or'd pair of opposite shifts:
+  // or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1)
+  if (isa<BinaryOperator>(Or0) && isa<BinaryOperator>(Or1)) {
+    Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1;
+    if (!match(Or0,
+               m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) ||
+        !match(Or1,
+               m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) ||
+        Or0->getOpcode() == Or1->getOpcode())
+      return nullptr;
+
+    // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)).
+    if (Or0->getOpcode() == BinaryOperator::LShr) {
+      std::swap(Or0, Or1);
+      std::swap(ShVal0, ShVal1);
+      std::swap(ShAmt0, ShAmt1);
     }
+    assert(Or0->getOpcode() == BinaryOperator::Shl &&
+           Or1->getOpcode() == BinaryOperator::LShr &&
+           "Illegal or(shift,shift) pair");
+
+    // Match the shift amount operands for a funnel shift pattern. This always
+    // matches a subtraction on the R operand.
+    auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * {
+      // Check for constant shift amounts that sum to the bitwidth.
+      const APInt *LI, *RI;
+      if (match(L, m_APIntAllowUndef(LI)) && match(R, m_APIntAllowUndef(RI)))
+        if (LI->ult(Width) && RI->ult(Width) && (*LI + *RI) == Width)
+          return ConstantInt::get(L->getType(), *LI);
+
+      Constant *LC, *RC;
+      if (match(L, m_Constant(LC)) && match(R, m_Constant(RC)) &&
+          match(L,
+                m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
+          match(R,
+                m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
+          match(ConstantExpr::getAdd(LC, RC), m_SpecificIntAllowUndef(Width)))
+        return ConstantExpr::mergeUndefsWith(LC, RC);
+
+      // (shl ShVal, X) | (lshr ShVal, (Width - x)) iff X < Width.
+      // We limit this to X < Width in case the backend re-expands the
+      // intrinsic, and has to reintroduce a shift modulo operation (InstCombine
+      // might remove it after this fold). This still doesn't guarantee that the
+      // final codegen will match this original pattern.
+      if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) {
+        KnownBits KnownL = IC.computeKnownBits(L, /*Depth*/ 0, &Or);
+        return KnownL.getMaxValue().ult(Width) ? L : nullptr;
+      }
+
+      // For non-constant cases, the following patterns currently only work for
+      // rotation patterns.
+      // TODO: Add general funnel-shift compatible patterns.
+      if (ShVal0 != ShVal1)
+        return nullptr;
+
+      // For non-constant cases we don't support non-pow2 shift masks.
+      // TODO: Is it worth matching urem as well?
+      if (!isPowerOf2_32(Width))
+        return nullptr;
+
+      // The shift amount may be masked with negation:
+      // (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1)))
+      Value *X;
+      unsigned Mask = Width - 1;
+      if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) &&
+          match(R, m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask))))
+        return X;
+
+      // Similar to above, but the shift amount may be extended after masking,
+      // so return the extended value as the parameter for the intrinsic.
+      if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
+          match(R,
+                m_And(m_Neg(m_ZExt(m_And(m_Specific(X), m_SpecificInt(Mask)))),
+                      m_SpecificInt(Mask))))
+        return L;
+
+      if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
+          match(R, m_ZExt(m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))))
+        return L;
 
-    // For non-constant cases, the following patterns currently only work for
-    // rotation patterns.
-    // TODO: Add general funnel-shift compatible patterns.
-    if (ShVal0 != ShVal1)
       return nullptr;
+    };
 
-    // For non-constant cases we don't support non-pow2 shift masks.
-    // TODO: Is it worth matching urem as well?
-    if (!isPowerOf2_32(Width))
+    Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, Width);
+    if (!ShAmt) {
+      ShAmt = matchShiftAmount(ShAmt1, ShAmt0, Width);
+      IsFshl = false; // Sub on SHL.
+    }
+    if (!ShAmt)
       return nullptr;
 
-    // The shift amount may be masked with negation:
-    // (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1)))
-    Value *X;
-    unsigned Mask = Width - 1;
-    if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) &&
-        match(R, m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask))))
-      return X;
+    FShiftArgs = {ShVal0, ShVal1, ShAmt};
 
-    // Similar to above, but the shift amount may be extended after masking,
-    // so return the extended value as the parameter for the intrinsic.
-    if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
-        match(R, m_And(m_Neg(m_ZExt(m_And(m_Specific(X), m_SpecificInt(Mask)))),
-                       m_SpecificInt(Mask))))
-      return L;
+  } else if (isa<ZExtInst>(Or0) || isa<ZExtInst>(Or1)) {
+    // If there are two 'or' instructions concat variables in opposite order,
+    // the latter one can be safely convert to fshl.
+    //
+    // LowHigh = or (shl (zext Low), Width - ZextHighShlAmt), (zext High)
+    // HighLow = or (shl (zext High), ZextHighShlAmt), (zext Low)
+    // ->
+    // HighLow = fshl LowHigh, LowHigh, ZextHighShlAmt
+    if (!isa<ZExtInst>(Or1))
+      std::swap(Or0, Or1);
+
+    Value *High, *ZextHigh, *Low;
+    const APInt *ZextHighShlAmt;
+    if (!match(Or0,
+               m_OneUse(m_Shl(m_Value(ZextHigh), m_APInt(ZextHighShlAmt)))))
+      return nullptr;
 
-    if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
-        match(R, m_ZExt(m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))))
-      return L;
+    if (!match(Or1, m_ZExt(m_Value(Low))) ||
+        !match(ZextHigh, m_ZExt(m_Value(High))))
+      return nullptr;
 
-    return nullptr;
-  };
+    unsigned HighSize = High->getType()->getScalarSizeInBits();
+    unsigned LowSize = Low->getType()->getScalarSizeInBits();
+    if (*ZextHighShlAmt != LowSize || HighSize + LowSize != Width)
+      return nullptr;
 
-  Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, Width);
-  bool IsFshl = true; // Sub on LSHR.
-  if (!ShAmt) {
-    ShAmt = matchShiftAmount(ShAmt1, ShAmt0, Width);
-    IsFshl = false; // Sub on SHL.
+    for (User *U : ZextHigh->users()) {
+      Value *X, *Y;
+      if (!match(U, m_Or(m_Value(X), m_Value(Y))))
+        continue;
+
+      if (!isa<ZExtInst>(Y))
+        std::swap(X, Y);
+
+      if (match(X, m_Shl(m_Specific(Or1), m_SpecificInt(HighSize))) &&
+          match(Y, m_Specific(ZextHigh)) && DT.dominates(U, &Or)) {
+        FShiftArgs = {U, U, ConstantInt::get(Or0->getType(), *ZextHighShlAmt)};
+        break;
+      }
+    }
   }
-  if (!ShAmt)
+
+  if (FShiftArgs.empty())
     return nullptr;
 
   Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr;
   Function *F = Intrinsic::getDeclaration(Or.getModule(), IID, Or.getType());
-  return CallInst::Create(F, {ShVal0, ShVal1, ShAmt});
+  return CallInst::Create(F, FShiftArgs);
 }
 
 /// Attempt to combine or(zext(x),shl(zext(y),bw/2) concat packing patterns.
@@ -3319,7 +3375,7 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
                                                   /*MatchBitReversals*/ true))
     return BitOp;
 
-  if (Instruction *Funnel = matchFunnelShift(I, *this))
+  if (Instruction *Funnel = matchFunnelShift(I, *this, DT))
     return Funnel;
 
   if (Instruction *Concat = matchOrConcat(I, Builder))
diff --git a/llvm/test/Transforms/InstCombine/funnel.ll b/llvm/test/Transforms/InstCombine/funnel.ll
index 60ce49a1635623f..c5bd1aa7b4351bc 100644
--- a/llvm/test/Transforms/InstCombine/funnel.ll
+++ b/llvm/test/Transforms/InstCombine/funnel.ll
@@ -354,6 +354,48 @@ define <2 x i64> @fshl_select_vector(<2 x i64> %x, <2 x i64> %y, <2 x i64> %sham
   ret <2 x i64> %r
 }
 
+; Convert 'or concat' to fshl if opposite 'or concat' exists.
+
+define i32 @fshl_concat(i8 %x, i24 %y, ptr %addr) {
+; CHECK-LABEL: @fshl_concat(
+; CHECK-NEXT:    [[ZEXT_X:%.*]] = zext i8 [[X:%.*]] to i32
+; CHECK-NEXT:    [[SLX:%.*]] = shl nuw i32 [[ZEXT_X]], 24
+; CHECK-NEXT:    [[ZEXT_Y:%.*]] = zext i24 [[Y:%.*]] to i32
+; CHECK-NEXT:    [[XY:%.*]] = or i32 [[SLX]], [[ZEXT_Y]]
+; CHECK-NEXT:    store i32 [[XY]], ptr [[ADDR:%.*]], align 4
+; CHECK-NEXT:    [[YX:%.*]] = call i32 @llvm.fshl.i32(i32 [[XY]], i32 [[XY]], i32 8)
+; CHECK-NEXT:    ret i32 [[YX]]
+;
+  %zext.x = zext i8 %x to i32
+  %slx = shl nuw i32 %zext.x, 24
+  %zext.y = zext i24 %y to i32
+  %xy = or i32 %zext.y, %slx
+  store i32 %xy, ptr %addr, align 4
+  %sly = shl nuw i32 %zext.y, 8
+  %yx = or i32 %zext.x, %sly
+  ret i32 %yx
+}
+
+define <2 x i32> @fshl_concat_vector(<2 x i8> %x, <2 x i24> %y, ptr %addr) {
+; CHECK-LABEL: @fshl_concat_vector(
+; CHECK-NEXT:    [[ZEXT_X:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
+; CHECK-NEXT:    [[SLX:%.*]] = shl nuw <2 x i32> [[ZEXT_X]], <i32 24, i32 24>
+; CHECK-NEXT:    [[ZEXT_Y:%.*]] = zext <2 x i24> [[Y:%.*]] to <2 x i32>
+; CHECK-NEXT:    [[XY:%.*]] = or <2 x i32> [[SLX]], [[ZEXT_Y]]
+; CHECK-NEXT:    store <2 x i32> [[XY]], ptr [[ADDR:%.*]], align 4
+; CHECK-NEXT:    [[YX:%.*]] = call <2 x i32> @llvm.fshl.v2i32(<2 x i32> [[XY]], <2 x i32> [[XY]], <2 x i32> <i32 8, i32 8>)
+; CHECK-NEXT:    ret <2 x i32> [[YX]]
+;
+  %zext.x = zext <2 x i8> %x to <2 x i32>
+  %slx = shl nuw <2 x i32> %zext.x, <i32 24, i32 24>
+  %zext.y = zext <2 x i24> %y to <2 x i32>
+  %xy = or <2 x i32> %slx, %zext.y
+  store <2 x i32> %xy, ptr %addr, align 4
+  %sly = shl nuw <2 x i32> %zext.y, <i32 8, i32 8>
+  %yx = or <2 x i32> %sly, %zext.x
+  ret <2 x i32> %yx
+}
+
 ; Negative test - an oversized shift in the narrow type would produce the wrong value.
 
 define i8 @unmasked_shlop_unmasked_shift_amount(i32 %x, i32 %y, i32 %shamt) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/68502


More information about the llvm-commits mailing list