[llvm] [InstCombine] Convert or concat to fshl if opposite or concat exists (PR #68502)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Oct 9 23:40:08 PDT 2023
================
@@ -2727,105 +2727,161 @@ Instruction *InstCombinerImpl::matchBSwapOrBitReverse(Instruction &I,
}
/// Match UB-safe variants of the funnel shift intrinsic.
-static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC) {
+static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC,
+ const DominatorTree &DT) {
// TODO: Can we reduce the code duplication between this and the related
// rotate matching code under visitSelect and visitTrunc?
unsigned Width = Or.getType()->getScalarSizeInBits();
- // First, find an or'd pair of opposite shifts:
- // or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1)
- BinaryOperator *Or0, *Or1;
- if (!match(Or.getOperand(0), m_BinOp(Or0)) ||
- !match(Or.getOperand(1), m_BinOp(Or1)))
+ Instruction *Or0, *Or1;
+ if (!match(Or.getOperand(0), m_Instruction(Or0)) ||
+ !match(Or.getOperand(1), m_Instruction(Or1)))
return nullptr;
- Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1;
- if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) ||
- !match(Or1, m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) ||
- Or0->getOpcode() == Or1->getOpcode())
- return nullptr;
+ bool IsFshl = true; // Sub on LSHR.
+ SmallVector<Value *, 3> FShiftArgs;
- // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)).
- if (Or0->getOpcode() == BinaryOperator::LShr) {
- std::swap(Or0, Or1);
- std::swap(ShVal0, ShVal1);
- std::swap(ShAmt0, ShAmt1);
- }
- assert(Or0->getOpcode() == BinaryOperator::Shl &&
- Or1->getOpcode() == BinaryOperator::LShr &&
- "Illegal or(shift,shift) pair");
-
- // Match the shift amount operands for a funnel shift pattern. This always
- // matches a subtraction on the R operand.
- auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * {
- // Check for constant shift amounts that sum to the bitwidth.
- const APInt *LI, *RI;
- if (match(L, m_APIntAllowUndef(LI)) && match(R, m_APIntAllowUndef(RI)))
- if (LI->ult(Width) && RI->ult(Width) && (*LI + *RI) == Width)
- return ConstantInt::get(L->getType(), *LI);
-
- Constant *LC, *RC;
- if (match(L, m_Constant(LC)) && match(R, m_Constant(RC)) &&
- match(L, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
- match(R, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
- match(ConstantExpr::getAdd(LC, RC), m_SpecificIntAllowUndef(Width)))
- return ConstantExpr::mergeUndefsWith(LC, RC);
-
- // (shl ShVal, X) | (lshr ShVal, (Width - x)) iff X < Width.
- // We limit this to X < Width in case the backend re-expands the intrinsic,
- // and has to reintroduce a shift modulo operation (InstCombine might remove
- // it after this fold). This still doesn't guarantee that the final codegen
- // will match this original pattern.
- if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) {
- KnownBits KnownL = IC.computeKnownBits(L, /*Depth*/ 0, &Or);
- return KnownL.getMaxValue().ult(Width) ? L : nullptr;
+ // First, find an or'd pair of opposite shifts:
+ // or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1)
+ if (isa<BinaryOperator>(Or0) && isa<BinaryOperator>(Or1)) {
+ Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1;
+ if (!match(Or0,
+ m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) ||
+ !match(Or1,
+ m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) ||
+ Or0->getOpcode() == Or1->getOpcode())
+ return nullptr;
+
+ // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)).
+ if (Or0->getOpcode() == BinaryOperator::LShr) {
+ std::swap(Or0, Or1);
+ std::swap(ShVal0, ShVal1);
+ std::swap(ShAmt0, ShAmt1);
}
+ assert(Or0->getOpcode() == BinaryOperator::Shl &&
+ Or1->getOpcode() == BinaryOperator::LShr &&
+ "Illegal or(shift,shift) pair");
+
+ // Match the shift amount operands for a funnel shift pattern. This always
+ // matches a subtraction on the R operand.
+ auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * {
+ // Check for constant shift amounts that sum to the bitwidth.
+ const APInt *LI, *RI;
+ if (match(L, m_APIntAllowUndef(LI)) && match(R, m_APIntAllowUndef(RI)))
+ if (LI->ult(Width) && RI->ult(Width) && (*LI + *RI) == Width)
+ return ConstantInt::get(L->getType(), *LI);
+
+ Constant *LC, *RC;
+ if (match(L, m_Constant(LC)) && match(R, m_Constant(RC)) &&
+ match(L,
+ m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
+ match(R,
+ m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
+ match(ConstantExpr::getAdd(LC, RC), m_SpecificIntAllowUndef(Width)))
+ return ConstantExpr::mergeUndefsWith(LC, RC);
+
+ // (shl ShVal, X) | (lshr ShVal, (Width - x)) iff X < Width.
+ // We limit this to X < Width in case the backend re-expands the
+ // intrinsic, and has to reintroduce a shift modulo operation (InstCombine
+ // might remove it after this fold). This still doesn't guarantee that the
+ // final codegen will match this original pattern.
+ if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) {
+ KnownBits KnownL = IC.computeKnownBits(L, /*Depth*/ 0, &Or);
+ return KnownL.getMaxValue().ult(Width) ? L : nullptr;
+ }
+
+ // For non-constant cases, the following patterns currently only work for
+ // rotation patterns.
+ // TODO: Add general funnel-shift compatible patterns.
+ if (ShVal0 != ShVal1)
+ return nullptr;
+
+ // For non-constant cases we don't support non-pow2 shift masks.
+ // TODO: Is it worth matching urem as well?
+ if (!isPowerOf2_32(Width))
+ return nullptr;
+
+ // The shift amount may be masked with negation:
+ // (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1)))
+ Value *X;
+ unsigned Mask = Width - 1;
+ if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) &&
+ match(R, m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask))))
+ return X;
+
+ // Similar to above, but the shift amount may be extended after masking,
+ // so return the extended value as the parameter for the intrinsic.
+ if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
+ match(R,
+ m_And(m_Neg(m_ZExt(m_And(m_Specific(X), m_SpecificInt(Mask)))),
+ m_SpecificInt(Mask))))
+ return L;
+
+ if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
+ match(R, m_ZExt(m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))))
+ return L;
- // For non-constant cases, the following patterns currently only work for
- // rotation patterns.
- // TODO: Add general funnel-shift compatible patterns.
- if (ShVal0 != ShVal1)
return nullptr;
+ };
- // For non-constant cases we don't support non-pow2 shift masks.
- // TODO: Is it worth matching urem as well?
- if (!isPowerOf2_32(Width))
+ Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, Width);
+ if (!ShAmt) {
+ ShAmt = matchShiftAmount(ShAmt1, ShAmt0, Width);
+ IsFshl = false; // Sub on SHL.
+ }
+ if (!ShAmt)
return nullptr;
- // The shift amount may be masked with negation:
- // (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1)))
- Value *X;
- unsigned Mask = Width - 1;
- if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) &&
- match(R, m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask))))
- return X;
+ FShiftArgs = {ShVal0, ShVal1, ShAmt};
- // Similar to above, but the shift amount may be extended after masking,
- // so return the extended value as the parameter for the intrinsic.
- if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
- match(R, m_And(m_Neg(m_ZExt(m_And(m_Specific(X), m_SpecificInt(Mask)))),
- m_SpecificInt(Mask))))
- return L;
+ } else if (isa<ZExtInst>(Or0) || isa<ZExtInst>(Or1)) {
+ // If there are two 'or' instructions that concatenate the same variables in
+ // opposite orders, the latter one can safely be converted to fshl.
+ //
+ // LowHigh = or (shl (zext Low), Width - ZextHighShlAmt), (zext High)
+ // HighLow = or (shl (zext High), ZextHighShlAmt), (zext Low)
+ // ->
+ // HighLow = fshl LowHigh, LowHigh, ZextHighShlAmt
+ if (!isa<ZExtInst>(Or1))
+ std::swap(Or0, Or1);
+
+ Value *High, *ZextHigh, *Low;
+ const APInt *ZextHighShlAmt;
+ if (!match(Or0,
+ m_OneUse(m_Shl(m_Value(ZextHigh), m_APInt(ZextHighShlAmt)))))
+ return nullptr;
- if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
- match(R, m_ZExt(m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))))
- return L;
+ if (!match(Or1, m_ZExt(m_Value(Low))) ||
+ !match(ZextHigh, m_ZExt(m_Value(High))))
+ return nullptr;
- return nullptr;
- };
+ unsigned HighSize = High->getType()->getScalarSizeInBits();
+ unsigned LowSize = Low->getType()->getScalarSizeInBits();
+ if (*ZextHighShlAmt != LowSize || HighSize + LowSize != Width)
+ return nullptr;
- Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, Width);
- bool IsFshl = true; // Sub on LSHR.
- if (!ShAmt) {
- ShAmt = matchShiftAmount(ShAmt1, ShAmt0, Width);
- IsFshl = false; // Sub on SHL.
+ for (User *U : ZextHigh->users()) {
+ Value *X, *Y;
+ if (!match(U, m_Or(m_Value(X), m_Value(Y))))
+ continue;
+
+ if (!isa<ZExtInst>(Y))
+ std::swap(X, Y);
+
+ if (match(X, m_Shl(m_Specific(Or1), m_SpecificInt(HighSize))) &&
+ match(Y, m_Specific(ZextHigh)) && DT.dominates(U, &Or)) {
+ FShiftArgs = {U, U, ConstantInt::get(Or0->getType(), *ZextHighShlAmt)};
+ break;
+ }
+ }
----------------
HaohaiWen wrote:
Given IR like this:
```
%zext.x = zext i8 %x to i32
%slx = shl nuw i32 %zext.x, 24
%zext.y = zext i24 %y to i32
%xy = or i32 %zext.y, %slx
store i32 %xy, ptr %addr, align 4
%sly = shl nuw i32 %zext.y, 8
%yx = or i32 %zext.x, %sly
%and.i3 = and i32 %yx, 16711935
%shl1.i = shl nuw i32 %and.i3, 8
%and2.i = and i32 %yx, -16711936
%shr3.i = lshr i32 %and2.i, 8
%or4.i = or i32 %shl1.i, %shr3.i
%and5.i = and i32 %or4.i, 252645135
%shl6.i = shl nuw i32 %and5.i, 4
%and7.i = and i32 %or4.i, -252645136
%shr8.i = lshr i32 %and7.i, 4
%or9.i = or i32 %shl6.i, %shr8.i
%and10.i = and i32 %or9.i, 858993459
%shl11.i = shl nuw i32 %and10.i, 2
%and12.i = and i32 %or9.i, -858993460
%shr13.i = lshr i32 %and12.i, 2
%or14.i = or i32 %shl11.i, %shr13.i
%and15.i = and i32 %or14.i, 1431655765
%shl16.i = shl nuw i32 %and15.i, 1
%and17.i = and i32 %or14.i, -1431655766
%shr18.i = lshr i32 %and17.i, 1
%or19.i = or i32 %shl16.i, %shr18.i
```
Apparently, it can be optimized to:
```
%zext.x = zext i8 %x to i32
%slx = shl nuw i32 %zext.x, 24
%zext.y = zext i24 %y to i32
%xy = or i32 %zext.y, %slx
store i32 %xy, ptr %addr, align 4
%res = call i32 @llvm.bitreverse.i32(i32 %xy)
```
This requires first optimizing
```
%sly = shl nuw i32 %zext.y, 8
%yx = or i32 %zext.x, %sly
```
to
```
fshl %xy, %xy, 8
```
and then teaching InstCombine to optimize the rest to bitreverse.
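Spelled out as actual IR, the intermediate state after that first step would look roughly like this (a sketch assuming only the `%sly`/`%yx` pair is replaced by the funnel shift intrinsic; the mask-and-shift chain on `%yx` is unchanged and elided):
```
%zext.x = zext i8 %x to i32
%slx = shl nuw i32 %zext.x, 24
%zext.y = zext i24 %y to i32
%xy = or i32 %zext.y, %slx
store i32 %xy, ptr %addr, align 4
; %yx is now a rotate of %xy by 8, expressed via the fshl intrinsic
%yx = call i32 @llvm.fshl.i32(i32 %xy, i32 %xy, i32 8)
; ... %and.i3 through %or19.i operate on %yx as before
```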
If we did this funnel shift optimization in the DAG combiner instead, there would be no chance to optimize the rest to bitreverse at the IR level.
https://github.com/llvm/llvm-project/pull/68502