[llvm] [GlobalIsel] Combine selects with constants (PR #76089)

Amara Emerson via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 2 08:04:08 PST 2024


Thorsten Schütt <schuett at gmail.com>,
Thorsten Schütt <schuett at gmail.com>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/76089 at github.com>


================
@@ -6318,3 +6262,300 @@ void CombinerHelper::applyCommuteBinOpOperands(MachineInstr &MI) {
   MI.getOperand(2).setReg(LHSReg);
   Observer.changedInstr(MI);
 }
+
+// Returns true if Src is the integer constant 1, or a fixed-length vector
+// splat of 1 (checked via isConstantSplatVector). When AllowUndefs is set, a
+// scalar defined by G_IMPLICIT_DEF is also accepted, and undef vector
+// elements are tolerated. Any other type (e.g. scalable vectors) yields false.
+bool CombinerHelper::isOneOrOneSplat(Register Src, bool AllowUndefs) {
+  const LLT Ty = MRI.getType(Src);
+  if (Ty.isFixedVector())
+    return isConstantSplatVector(Src, 1, AllowUndefs);
+  // Neither a scalar nor a fixed vector: not handled (scalable vectors, ...).
+  if (!Ty.isScalar())
+    return false;
+  if (AllowUndefs && getOpcodeDef<GImplicitDef>(Src, MRI))
+    return true;
+  auto Cst = getIConstantVRegValWithLookThrough(Src, MRI);
+  return Cst.has_value() && Cst->Value.isOne();
+}
+
+// Returns true if Src is the integer constant 0, or a fixed-length vector
+// splat of 0 (checked via isConstantSplatVector). When AllowUndefs is set, a
+// scalar defined by G_IMPLICIT_DEF is also accepted, and undef vector
+// elements are tolerated. Any other type (e.g. scalable vectors) yields false.
+bool CombinerHelper::isZeroOrZeroSplat(Register Src, bool AllowUndefs) {
+  const LLT Ty = MRI.getType(Src);
+  if (Ty.isFixedVector())
+    return isConstantSplatVector(Src, 0, AllowUndefs);
+  // Neither a scalar nor a fixed vector: not handled (scalable vectors, ...).
+  if (!Ty.isScalar())
+    return false;
+  if (AllowUndefs && getOpcodeDef<GImplicitDef>(Src, MRI))
+    return true;
+  auto Cst = getIConstantVRegValWithLookThrough(Src, MRI);
+  return Cst.has_value() && Cst->Value.isZero();
+}
+
+// Returns true if Src is defined by a G_BUILD_VECTOR whose every source is
+// the integer constant SplatValue. When AllowUndefs is set, sources defined
+// by G_IMPLICIT_DEF are skipped; otherwise any undef source fails the check.
+// Definitions are found with getOpcodeDef / getIConstantVRegValWithLookThrough,
+// so COPYs are ignored during the conformance checks.
+//
+// Each element is compared at its own bit width. A non-negative SplatValue
+// must equal the zero-extended element. A negative SplatValue must equal the
+// sign-extended element, so e.g. -1 matches an all-ones s32 element.
+// FIXME scalable vectors.
+bool CombinerHelper::isConstantSplatVector(Register Src, int64_t SplatValue,
+                                           bool AllowUndefs) {
+  GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI);
+  if (!BuildVector)
+    return false;
+
+  for (unsigned I = 0, E = BuildVector->getNumSources(); I < E; ++I) {
+    Register SrcReg = BuildVector->getSourceReg(I);
+    if (getOpcodeDef<GImplicitDef>(SrcReg, MRI)) {
+      if (!AllowUndefs)
+        return false;
+      continue;
+    }
+    std::optional<ValueAndVReg> IConstant =
+        getIConstantVRegValWithLookThrough(SrcReg, MRI);
+    if (!IConstant)
+      return false;
+    const APInt &Elt = IConstant->Value;
+    // APInt::operator==(uint64_t) zero-extends the element, so on its own it
+    // could never match a negative SplatValue for elements narrower than 64
+    // bits. Negative values are compared in the signed domain instead.
+    bool Matches = SplatValue >= 0
+                       ? Elt == static_cast<uint64_t>(SplatValue)
+                       : Elt.trySExtValue() == SplatValue;
+    if (!Matches)
+      return false;
+  }
+  return true;
+}
+
+// Returns the integer constant defining Src. If Src is instead defined by a
+// G_BUILD_VECTOR whose sources are all the same integer constant, returns that
+// splatted value. Returns std::nullopt when any element is not an integer
+// constant or when the elements differ. COPYs are ignored during lookups.
+// FIXME scalable vectors
+std::optional<APInt>
+CombinerHelper::getConstantOrConstantSplatVector(Register Src) {
+  if (std::optional<ValueAndVReg> ScalarCst =
+          getIConstantVRegValWithLookThrough(Src, MRI))
+    return ScalarCst->Value;
+
+  auto *BV = getOpcodeDef<GBuildVector>(Src, MRI);
+  if (BV == nullptr)
+    return std::nullopt;
+
+  std::optional<APInt> Splat;
+  for (unsigned Idx = 0, End = BV->getNumSources(); Idx != End; ++Idx) {
+    std::optional<ValueAndVReg> EltCst =
+        getIConstantVRegValWithLookThrough(BV->getSourceReg(Idx), MRI);
+    if (!EltCst)
+      return std::nullopt;
+    // The first element fixes the candidate; every later one must agree.
+    if (Splat && *Splat != EltCst->Value)
+      return std::nullopt;
+    if (!Splat)
+      Splat = EltCst->Value;
+  }
+  return Splat;
+}
+
+// TODO: use knownbits to determine zeros
+bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
+                                              BuildFnTy &MatchInfo) {
+  uint32_t Flags = Select->getFlags();
+  Register Dest = Select->getReg(0);
+  Register Cond = Select->getCondReg();
+  Register True = Select->getTrueReg();
+  Register False = Select->getFalseReg();
+  LLT CondTy = MRI.getType(Select->getCondReg());
+  LLT TrueTy = MRI.getType(Select->getTrueReg());
+
+  // We only do this combine for scalar boolean conditions.
+  if (CondTy != LLT::scalar(1))
+    return false;
+
+  // Both are scalars.
+  std::optional<ValueAndVReg> TrueOpt =
+      getIConstantVRegValWithLookThrough(True, MRI);
+  std::optional<ValueAndVReg> FalseOpt =
+      getIConstantVRegValWithLookThrough(False, MRI);
+
+  if (!TrueOpt || !FalseOpt)
+    return false;
+
+  APInt TrueValue = TrueOpt->Value;
+  APInt FalseValue = FalseOpt->Value;
+
+  // select Cond, 1, 0 --> zext (Cond)
+  if (TrueValue.isOne() && FalseValue.isZero()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      B.buildZExtOrTrunc(Dest, Cond);
+    };
+    return true;
+  }
+
+  // select Cond, -1, 0 --> sext (Cond)
+  if (TrueValue.isAllOnes() && FalseValue.isZero()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      B.buildSExtOrTrunc(Dest, Cond);
+    };
+    return true;
+  }
+
+  // select Cond, 0, 1 --> zext (!Cond)
+  if (TrueValue.isZero() && FalseValue.isOne()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(CondTy);
+      B.buildNot(Inner, Cond);
+      B.buildZExtOrTrunc(Dest, Inner);
+    };
+    return true;
+  }
+
+  // select Cond, 0, -1 --> sext (!Cond)
+  if (TrueValue.isZero() && FalseValue.isAllOnes()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(CondTy);
+      B.buildNot(Inner, Cond);
+      B.buildSExtOrTrunc(Dest, Inner);
+    };
+    return true;
+  }
+
+  // select Cond, C1, C1-1 --> add (zext Cond), C1-1
+  if (TrueValue - 1 == FalseValue) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Inner, Cond);
+      B.buildAdd(Dest, Inner, False);
+    };
+    return true;
+  }
+
+  // select Cond, C1, C1+1 --> add (sext Cond), C1+1
+  if (TrueValue + 1 == FalseValue) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildSExtOrTrunc(Inner, Cond);
+      B.buildAdd(Dest, Inner, False);
+    };
+    return true;
+  }
+
+  // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
+  if (TrueValue.isPowerOf2() && FalseValue.isZero()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Inner, Cond);
+      // The shift amount must be scalar.
+      LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy;
+      auto ShAmtC = B.buildConstant(ShiftTy, TrueValue.exactLogBase2());
+      B.buildShl(Dest, Inner, ShAmtC, Flags);
+    };
+    return true;
+  }
+  // select Cond, -1, C --> or (sext Cond), C
+  if (TrueValue.isAllOnes()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildSExtOrTrunc(Inner, Cond);
----------------
aemerson wrote:

It's idiomatic in GlobalISel code to pass the LLT directly to the builder when possible; you'll notice that we rarely call `createGenericVirtualRegister()` directly. That said, if you feel very strongly about this I won't block this patch on that basis, but you should be aware that this code will be in the minority w.r.t. the rest of the codebase.

https://github.com/llvm/llvm-project/pull/76089


More information about the llvm-commits mailing list