[llvm] [GlobalIsel] Combine G_SELECT (PR #74845)
Thorsten Schütt via llvm-commits
llvm-commits at lists.llvm.org
Mon Dec 11 01:33:47 PST 2023
================
@@ -6318,3 +6246,343 @@ void CombinerHelper::applyCommuteBinOpOperands(MachineInstr &MI) {
MI.getOperand(2).setReg(LHSReg);
Observer.changedInstr(MI);
}
+
+// Returns true if \p Src is the integer constant 1, or a fixed vector for
+// which isConstantSplatVector(Src, 1, AllowUndefs) holds. With \p AllowUndefs
+// set, a scalar defined by G_IMPLICIT_DEF is accepted as well.
+bool CombinerHelper::isOneOrOneSplat(Register Src, bool AllowUndefs) {
+  const LLT Ty = MRI.getType(Src);
+
+  // Fixed vectors are delegated to the splat check.
+  if (Ty.isFixedVector())
+    return isConstantSplatVector(Src, 1, AllowUndefs);
+
+  // Anything that is neither a fixed vector nor a scalar (e.g. a scalable
+  // vector) is rejected.
+  if (!Ty.isScalar())
+    return false;
+
+  if (AllowUndefs && getOpcodeDef<GImplicitDef>(Src, MRI))
+    return true;
+
+  if (auto MaybeCst = getIConstantVRegValWithLookThrough(Src, MRI))
+    return MaybeCst->Value == 1;
+  return false;
+}
+
+// Returns true if \p Src is the integer constant 0, or a fixed vector for
+// which isConstantSplatVector(Src, 0, AllowUndefs) holds. With \p AllowUndefs
+// set, a scalar defined by G_IMPLICIT_DEF is accepted as well.
+bool CombinerHelper::isZeroOrZeroSplat(Register Src, bool AllowUndefs) {
+  const LLT Ty = MRI.getType(Src);
+
+  // Fixed vectors are delegated to the splat check.
+  if (Ty.isFixedVector())
+    return isConstantSplatVector(Src, 0, AllowUndefs);
+
+  // Anything that is neither a fixed vector nor a scalar (e.g. a scalable
+  // vector) is rejected.
+  if (!Ty.isScalar())
+    return false;
+
+  if (AllowUndefs && getOpcodeDef<GImplicitDef>(Src, MRI))
+    return true;
+
+  if (auto MaybeCst = getIConstantVRegValWithLookThrough(Src, MRI))
+    return MaybeCst->Value == 0;
+  return false;
+}
+
+// Returns true if \p Src is defined by a G_BUILD_VECTOR whose every source is
+// the integer constant \p SplatValue. When \p AllowUndefs is set, sources
+// defined by G_IMPLICIT_DEF are skipped; otherwise any such source makes the
+// check fail.
+// Ignores COPYs during conformance checks.
+// FIXME scalable vectors.
+bool CombinerHelper::isConstantSplatVector(Register Src, int64_t SplatValue,
+                                           bool AllowUndefs) {
+  GBuildVector *BV = getOpcodeDef<GBuildVector>(Src, MRI);
+  if (!BV)
+    return false;
+
+  for (unsigned Idx = 0, NumElts = BV->getNumSources(); Idx != NumElts; ++Idx) {
+    Register Elt = BV->getSourceReg(Idx);
+
+    // Undefined elements either pass outright or veto the whole vector.
+    if (getOpcodeDef<GImplicitDef>(Elt, MRI)) {
+      if (!AllowUndefs)
+        return false;
+      continue;
+    }
+
+    std::optional<ValueAndVReg> EltCst =
+        getIConstantVRegValWithLookThrough(Elt, MRI);
+    const bool Matches = EltCst && EltCst->Value == SplatValue;
+    if (!Matches)
+      return false;
+  }
+  return true;
+}
+
+// Returns the value of \p Src if it is an integer constant, or the common
+// element value if it is a G_BUILD_VECTOR whose sources are all the same
+// integer constant. Returns std::nullopt in every other case.
+// Ignores COPYs during lookups.
+// FIXME scalable vectors
+std::optional<APInt>
+CombinerHelper::getConstantOrConstantSplatVector(Register Src) {
+  // Plain constant: nothing more to inspect.
+  if (auto ScalarCst = getIConstantVRegValWithLookThrough(Src, MRI))
+    return ScalarCst->Value;
+
+  GBuildVector *BV = getOpcodeDef<GBuildVector>(Src, MRI);
+  if (!BV)
+    return std::nullopt;
+
+  // The first element fixes the candidate; every later one must agree.
+  std::optional<APInt> Splat;
+  for (unsigned Idx = 0, NumElts = BV->getNumSources(); Idx != NumElts; ++Idx) {
+    std::optional<ValueAndVReg> EltCst =
+        getIConstantVRegValWithLookThrough(BV->getSourceReg(Idx), MRI);
+    if (!EltCst)
+      return std::nullopt;
+    if (Splat && *Splat != EltCst->Value)
+      return std::nullopt;
+    if (!Splat)
+      Splat = EltCst->Value;
+  }
+  return Splat;
+}
+
+// Folds a G_SELECT whose true and false operands are both integer constants
+// (or build-vector splats of a single constant) and whose condition is a
+// boolean (s1) or a fixed vector of booleans into an extend/add/shl/or
+// sequence.
+//
+// \p Select is the instruction to fold. On success, \p MatchInfo is set to a
+// callback that builds the replacement into the select's destination register
+// and true is returned. Otherwise false is returned and \p MatchInfo is left
+// untouched.
+//
+// TODO: use knownbits to determine zeros
+bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
+                                              BuildFnTy &MatchInfo) {
+  uint32_t Flags = Select->getFlags();
+  Register Dest = Select->getReg(0);
+  Register Cond = Select->getCondReg();
+  Register True = Select->getTrueReg();
+  Register False = Select->getFalseReg();
+  LLT CondTy = MRI.getType(Cond);
+  LLT TrueTy = MRI.getType(True);
+
+  // Either both are scalars or both are vectors.
+  std::optional<APInt> TrueOpt = getConstantOrConstantSplatVector(True);
+  std::optional<APInt> FalseOpt = getConstantOrConstantSplatVector(False);
+
+  if (!TrueOpt || !FalseOpt)
+    return false;
+
+  // These are only the splat values.
+  APInt TrueValue = *TrueOpt;
+  APInt FalseValue = *FalseOpt;
+
+  // Boolean or fixed vector of booleans.
+  if (CondTy.isScalableVector() ||
+      (CondTy.isFixedVector() &&
+       CondTy.getElementType().getScalarSizeInBits() != 1) ||
+      CondTy.getScalarSizeInBits() != 1)
+    return false;
+
+  // Every rewrite below extends Cond (or its negation) to the type of the
+  // selected values. An extension cannot turn a scalar into a vector, so a
+  // scalar condition choosing between vector values is not handled here.
+  if (CondTy.isVector() != TrueTy.isVector())
+    return false;
+
+  // select Cond, 1, 0 --> zext (Cond)
+  if (TrueValue.isOne() && FalseValue.isZero()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      B.buildZExtOrTrunc(Dest, Cond);
+    };
+    return true;
+  }
+
+  // select Cond, -1, 0 --> sext (Cond)
+  if (TrueValue.isAllOnes() && FalseValue.isZero()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      B.buildSExtOrTrunc(Dest, Cond);
+    };
+    return true;
+  }
+
+  // select Cond, 0, 1 --> zext (!Cond)
+  if (TrueValue.isZero() && FalseValue.isOne()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(CondTy);
+      B.buildNot(Inner, Cond);
+      B.buildZExtOrTrunc(Dest, Inner);
+    };
+    return true;
+  }
+
+  // select Cond, 0, -1 --> sext (!Cond)
+  if (TrueValue.isZero() && FalseValue.isAllOnes()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(CondTy);
+      B.buildNot(Inner, Cond);
+      B.buildSExtOrTrunc(Dest, Inner);
+    };
+    return true;
+  }
+
+  // select Cond, C1, C1-1 --> add (zext Cond), C1-1
+  if (TrueValue - 1 == FalseValue) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Inner, Cond);
+      B.buildAdd(Dest, Inner, False);
+    };
+    return true;
+  }
+
+  // select Cond, C1, C1+1 --> add (sext Cond), C1+1
+  if (TrueValue + 1 == FalseValue) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildSExtOrTrunc(Inner, Cond);
+      B.buildAdd(Dest, Inner, False);
+    };
+    return true;
+  }
+
+  // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
+  if (TrueValue.isPowerOf2() && FalseValue.isZero()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Inner, Cond);
+      // The shift amount must be scalar.
+      LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy;
+      auto ShAmtC = B.buildConstant(ShiftTy, TrueValue.exactLogBase2());
+      B.buildShl(Dest, Inner, ShAmtC, Flags);
+    };
+    return true;
+  }
+
+  // select Cond, -1, C --> or (sext Cond), C
+  if (TrueValue.isAllOnes()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildSExtOrTrunc(Inner, Cond);
+      B.buildOr(Dest, Inner, False, Flags);
+    };
+    return true;
+  }
+
+  // select Cond, C, -1 --> or (sext (not Cond)), C
+  if (FalseValue.isAllOnes()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Not = MRI.createGenericVirtualRegister(CondTy);
+      B.buildNot(Not, Cond);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildSExtOrTrunc(Inner, Not);
+      B.buildOr(Dest, Inner, True, Flags);
+    };
+    return true;
+  }
+
+  // No known constant pattern matched.
+  return false;
+}
+
+// TODO: use knownbits to determine zeros
+bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
+ BuildFnTy &MatchInfo) {
+ uint32_t Flags = Select->getFlags();
+ Register DstReg = Select->getReg(0);
+ Register Cond = Select->getCondReg();
+ Register True = Select->getTrueReg();
+ Register False = Select->getFalseReg();
+ LLT CondTy = MRI.getType(Select->getCondReg());
+ LLT TrueTy = MRI.getType(Select->getTrueReg());
+
+ // Boolean or fixed vector of booleans.
+ if (CondTy.isScalableVector() ||
+ (CondTy.isFixedVector() &&
+ CondTy.getElementType().getScalarSizeInBits() != 1) ||
+ CondTy.getScalarSizeInBits() != 1)
+ return false;
+
+ // select Cond, Cond, F --> or Cond, F
+ // select Cond, 1, F --> or Cond, F
+ if ((Cond == True) || isOneOrOneSplat(True, /* AllowUndefs */ true)) {
+ MatchInfo = [=](MachineIRBuilder &B) {
+ B.setInstrAndDebugLoc(*Select);
+ Register Ext = MRI.createGenericVirtualRegister(TrueTy);
+ B.buildZExtOrTrunc(Ext, Cond);
+ B.buildOr(DstReg, Ext, False, Flags);
+ };
+ return true;
+ }
+
+ // select Cond, T, Cond --> and Cond, T
+ // select Cond, T, 0 --> and Cond, T
+ if ((Cond == False) || isZeroOrZeroSplat(False, /* AllowUndefs */ true)) {
+ MatchInfo = [=](MachineIRBuilder &B) {
+ B.setInstrAndDebugLoc(*Select);
+ Register Ext = MRI.createGenericVirtualRegister(TrueTy);
+ B.buildZExtOrTrunc(Ext, Cond);
+ B.buildAnd(DstReg, Ext, True);
+ };
+ return true;
+ }
+
+ // select Cond, T, 1 --> or (not Cond), T
+ if (isOneOrOneSplat(False, /* AllowUndefs */ true)) {
+ MatchInfo = [=](MachineIRBuilder &B) {
+ B.setInstrAndDebugLoc(*Select);
+ // First the not.
+ Register Inner = MRI.createGenericVirtualRegister(CondTy);
+ B.buildNot(Inner, Cond);
----------------
tschuett wrote:
The `Inner` is dead.
https://github.com/llvm/llvm-project/pull/74845
More information about the llvm-commits
mailing list