[llvm] [GlobalIsel] Combine G_SELECT (PR #74845)

Thorsten Schütt via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 11 01:33:47 PST 2023


================
@@ -6318,3 +6246,343 @@ void CombinerHelper::applyCommuteBinOpOperands(MachineInstr &MI) {
   MI.getOperand(2).setReg(LHSReg);
   Observer.changedInstr(MI);
 }
+
+// Returns true if Src is the integer constant 1 (scalar) or a fixed-length
+// G_BUILD_VECTOR splat of 1. When AllowUndefs is set, undefined values
+// (G_IMPLICIT_DEF scalars or lanes) are accepted as well.
+bool CombinerHelper::isOneOrOneSplat(Register Src, bool AllowUndefs) {
+  const LLT Ty = MRI.getType(Src);
+
+  // Fixed-length vectors: defer to the per-lane splat check.
+  if (Ty.isFixedVector())
+    return isConstantSplatVector(Src, 1, AllowUndefs);
+
+  // Anything that is neither a fixed vector nor a plain scalar (e.g. a
+  // scalable vector) is conservatively rejected.
+  if (!Ty.isScalar())
+    return false;
+
+  // An undefined scalar may stand in for 1 when the caller permits it.
+  if (AllowUndefs && getOpcodeDef<GImplicitDef>(Src, MRI))
+    return true;
+
+  const std::optional<ValueAndVReg> Cst =
+      getIConstantVRegValWithLookThrough(Src, MRI);
+  return Cst.has_value() && Cst->Value == 1;
+}
+
+// Returns true if Src is the integer constant 0 (scalar) or a fixed-length
+// G_BUILD_VECTOR splat of 0. When AllowUndefs is set, undefined values
+// (G_IMPLICIT_DEF scalars or lanes) are accepted as well.
+bool CombinerHelper::isZeroOrZeroSplat(Register Src, bool AllowUndefs) {
+  const LLT Ty = MRI.getType(Src);
+
+  // Fixed-length vectors: defer to the per-lane splat check.
+  if (Ty.isFixedVector())
+    return isConstantSplatVector(Src, 0, AllowUndefs);
+
+  // Anything that is neither a fixed vector nor a plain scalar (e.g. a
+  // scalable vector) is conservatively rejected.
+  if (!Ty.isScalar())
+    return false;
+
+  // An undefined scalar may stand in for 0 when the caller permits it.
+  if (AllowUndefs && getOpcodeDef<GImplicitDef>(Src, MRI))
+    return true;
+
+  const std::optional<ValueAndVReg> Cst =
+      getIConstantVRegValWithLookThrough(Src, MRI);
+  return Cst.has_value() && Cst->Value == 0;
+}
+
+// Returns true if Src is defined by a G_BUILD_VECTOR whose every source is the
+// integer constant SplatValue or, when AllowUndefs is set, a G_IMPLICIT_DEF.
+// Ignores COPYs during conformance checks.
+//
+// Non-negative SplatValues are compared against the zero-extended element;
+// negative SplatValues are compared against the sign-extended element, so -1
+// matches an all-ones element of any width.
+// FIXME scalable vectors.
+bool CombinerHelper::isConstantSplatVector(Register Src, int64_t SplatValue,
+                                           bool AllowUndefs) {
+  GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI);
+  if (!BuildVector)
+    return false;
+
+  for (unsigned I = 0, E = BuildVector->getNumSources(); I != E; ++I) {
+    Register SrcReg = BuildVector->getSourceReg(I);
+
+    // Undefined lanes are either wildcards or disqualifying.
+    if (getOpcodeDef<GImplicitDef>(SrcReg, MRI)) {
+      if (!AllowUndefs)
+        return false;
+      continue;
+    }
+
+    std::optional<ValueAndVReg> IConstant =
+        getIConstantVRegValWithLookThrough(SrcReg, MRI);
+    if (!IConstant)
+      return false;
+
+    // APInt == uint64_t zero-extends the APInt. That is exact for
+    // non-negative SplatValues, but a negative SplatValue would then never
+    // match an element narrower than 64 bits, so compare it sign-extended.
+    const APInt &Elt = IConstant->Value;
+    bool Matches =
+        SplatValue >= 0
+            ? Elt == static_cast<uint64_t>(SplatValue)
+            : Elt.isSignedIntN(64) && Elt.getSExtValue() == SplatValue;
+    if (!Matches)
+      return false;
+  }
+  return true;
+}
+
+// Returns the value of Src if it is an integer constant, or the common value
+// of a G_BUILD_VECTOR whose sources are all the same integer constant.
+// Returns std::nullopt otherwise. Ignores COPYs during lookups.
+// FIXME scalable vectors
+std::optional<APInt>
+CombinerHelper::getConstantOrConstantSplatVector(Register Src) {
+  // Plain scalar constant.
+  if (std::optional<ValueAndVReg> Scalar =
+          getIConstantVRegValWithLookThrough(Src, MRI))
+    return Scalar->Value;
+
+  // Otherwise it must be a build vector of one repeated constant.
+  auto *BV = getOpcodeDef<GBuildVector>(Src, MRI);
+  if (BV == nullptr)
+    return std::nullopt;
+
+  std::optional<APInt> Splat;
+  for (unsigned Idx = 0, End = BV->getNumSources(); Idx != End; ++Idx) {
+    std::optional<ValueAndVReg> Elt =
+        getIConstantVRegValWithLookThrough(BV->getSourceReg(Idx), MRI);
+    // A non-constant lane, or a lane that differs from the first one, means
+    // this is not a constant splat.
+    if (!Elt || (Splat && *Splat != Elt->Value))
+      return std::nullopt;
+    Splat = Elt->Value;
+  }
+  return Splat;
+}
+
+// Folds a G_SELECT between two integer constants (or two splats of integer
+// constants) into cheaper arithmetic on the extended condition, e.g.
+//   select Cond, 1, 0 --> zext Cond.
+//
+// Handled shapes: an s1 condition with scalar integer operands, or an
+// <N x s1> condition with <N x sM> integer operands. Two cases are rejected
+// because the rewrites extend Cond directly into the result type and use
+// integer-only opcodes:
+//  * a scalar condition that selects between whole vectors;
+//  * pointer-typed operands.
+//
+// TODO: use knownbits to determine zeros
+bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
+                                              BuildFnTy &MatchInfo) {
+  uint32_t Flags = Select->getFlags();
+  Register Dest = Select->getReg(0);
+  Register Cond = Select->getCondReg();
+  Register True = Select->getTrueReg();
+  Register False = Select->getFalseReg();
+  LLT CondTy = MRI.getType(Cond);
+  LLT TrueTy = MRI.getType(True);
+
+  // Boolean or fixed vector of booleans. For vectors, getScalarSizeInBits()
+  // already returns the element size.
+  if (CondTy.isScalableVector() || TrueTy.isScalableVector() ||
+      CondTy.getScalarSizeInBits() != 1)
+    return false;
+
+  // Either both are scalars or both are vectors with the same element count.
+  // G_SELECT also permits a scalar condition with vector operands, which cannot
+  // be rewritten as an extension of Cond.
+  if (CondTy.isVector() != TrueTy.isVector())
+    return false;
+  if (CondTy.isVector() && CondTy.getNumElements() != TrueTy.getNumElements())
+    return false;
+
+  // zext/sext/add/shl/or below are integer-only operations.
+  if (TrueTy.getScalarType().isPointer())
+    return false;
+
+  // For vectors these are the splat values.
+  std::optional<APInt> TrueOpt = getConstantOrConstantSplatVector(True);
+  std::optional<APInt> FalseOpt = getConstantOrConstantSplatVector(False);
+  if (!TrueOpt || !FalseOpt)
+    return false;
+
+  APInt TrueValue = *TrueOpt;
+  APInt FalseValue = *FalseOpt;
+
+  // select Cond, 1, 0 --> zext (Cond)
+  if (TrueValue.isOne() && FalseValue.isZero()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      B.buildZExtOrTrunc(Dest, Cond);
+    };
+    return true;
+  }
+
+  // select Cond, -1, 0 --> sext (Cond)
+  if (TrueValue.isAllOnes() && FalseValue.isZero()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      B.buildSExtOrTrunc(Dest, Cond);
+    };
+    return true;
+  }
+
+  // select Cond, 0, 1 --> zext (!Cond)
+  if (TrueValue.isZero() && FalseValue.isOne()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(CondTy);
+      B.buildNot(Inner, Cond);
+      B.buildZExtOrTrunc(Dest, Inner);
+    };
+    return true;
+  }
+
+  // select Cond, 0, -1 --> sext (!Cond)
+  if (TrueValue.isZero() && FalseValue.isAllOnes()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(CondTy);
+      B.buildNot(Inner, Cond);
+      B.buildSExtOrTrunc(Dest, Inner);
+    };
+    return true;
+  }
+
+  // select Cond, C1, C1-1 --> add (zext Cond), C1-1
+  if (TrueValue - 1 == FalseValue) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Inner, Cond);
+      B.buildAdd(Dest, Inner, False);
+    };
+    return true;
+  }
+
+  // select Cond, C1, C1+1 --> add (sext Cond), C1+1
+  if (TrueValue + 1 == FalseValue) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildSExtOrTrunc(Inner, Cond);
+      B.buildAdd(Dest, Inner, False);
+    };
+    return true;
+  }
+
+  // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
+  if (TrueValue.isPowerOf2() && FalseValue.isZero()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Inner, Cond);
+      // G_SHL requires the shift amount to be a vector whenever the shifted
+      // value is one, so build the amount in TrueTy. buildConstant splats it
+      // for vector types.
+      auto ShAmtC = B.buildConstant(TrueTy, TrueValue.exactLogBase2());
+      B.buildShl(Dest, Inner, ShAmtC, Flags);
+    };
+    return true;
+  }
+
+  // select Cond, -1, C --> or (sext Cond), C
+  if (TrueValue.isAllOnes()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildSExtOrTrunc(Inner, Cond);
+      B.buildOr(Dest, Inner, False, Flags);
+    };
+    return true;
+  }
+
+  // select Cond, C, -1 --> or (sext (not Cond)), C
+  if (FalseValue.isAllOnes()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Not = MRI.createGenericVirtualRegister(CondTy);
+      B.buildNot(Not, Cond);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildSExtOrTrunc(Inner, Not);
+      B.buildOr(Dest, Inner, True, Flags);
+    };
+    return true;
+  }
+
+  return false;
+}
+
+// TODO: use knownbits to determine zeros
+bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
+                                              BuildFnTy &MatchInfo) {
+  uint32_t Flags = Select->getFlags();
+  Register DstReg = Select->getReg(0);
+  Register Cond = Select->getCondReg();
+  Register True = Select->getTrueReg();
+  Register False = Select->getFalseReg();
+  LLT CondTy = MRI.getType(Select->getCondReg());
+  LLT TrueTy = MRI.getType(Select->getTrueReg());
+
+  // Boolean or fixed vector of booleans.
+  if (CondTy.isScalableVector() ||
+      (CondTy.isFixedVector() &&
+       CondTy.getElementType().getScalarSizeInBits() != 1) ||
+      CondTy.getScalarSizeInBits() != 1)
+    return false;
+
+  // select Cond, Cond, F --> or Cond, F
+  // select Cond, 1, F    --> or Cond, F
+  if ((Cond == True) || isOneOrOneSplat(True, /* AllowUndefs */ true)) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Ext, Cond);
+      B.buildOr(DstReg, Ext, False, Flags);
+    };
+    return true;
+  }
+
+  // select Cond, T, Cond --> and Cond, T
+  // select Cond, T, 0    --> and Cond, T
+  if ((Cond == False) || isZeroOrZeroSplat(False, /* AllowUndefs */ true)) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Ext, Cond);
+      B.buildAnd(DstReg, Ext, True);
+    };
+    return true;
+  }
+
+  // select Cond, T, 1 --> or (not Cond), T
+  if (isOneOrOneSplat(False, /* AllowUndefs */ true)) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      // First the not.
+      Register Inner = MRI.createGenericVirtualRegister(CondTy);
+      B.buildNot(Inner, Cond);
----------------
tschuett wrote:

The `Inner` is dead.

https://github.com/llvm/llvm-project/pull/74845


More information about the llvm-commits mailing list