[llvm] [GlobalIsel] Combine G_SELECT (PR #74845)

via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 8 06:53:56 PST 2023


llvmbot wrote:



@llvm/pr-subscribers-backend-amdgpu

Author: Thorsten Schütt (tschuett)

<details>
<summary>Changes</summary>

Cleanups and preparation for more combines: add known-bits support for constant conditions, combine selects whose true and false registers are constants, and improve support for vector conditions.
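
To make the known-bits part concrete, here is a rough sketch of the idea; the helper name and plumbing are hypothetical and this is not the patch's code:

```cpp
// Hypothetical sketch: fold G_SELECT when known bits prove the condition is
// constant. Mirrors the idea only, not the patch's implementation.
#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/Support/KnownBits.h"

using namespace llvm;

static bool foldSelectWithKnownCondition(GSelect &Sel, GISelKnownBits &KB,
                                         MachineIRBuilder &B) {
  KnownBits Known = KB.getKnownBits(Sel.getCondReg());
  if (Known.isZero()) {
    // Condition is provably false -> the select yields the false operand.
    B.setInstrAndDebugLoc(Sel);
    B.buildCopy(Sel.getReg(0), Sel.getFalseReg());
    return true;
  }
  if (Known.isAllOnes()) {
    // A 1-bit condition that is provably all-ones is true -> true operand.
    B.setInstrAndDebugLoc(Sel);
    B.buildCopy(Sel.getReg(0), Sel.getTrueReg());
    return true;
  }
  return false;
}
```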

AMDGPU supports vector conditions. X86 has a TODO for vector conditions. AArch64 SVE supports SEL for vector conditions. How should vector conditions be implemented with NEON (e.g., with BSL; see arm64-vselect.ll)? Vector select currently asserts in the instruction selector.

buildNot does not support scalable vectors: we cannot create scalable constant vectors of -1, and there is no G_NOT. AArch64 SVE has a NOT and a DUP for broadcasting. Something akin to G_CONSTANT_SPLAT, G_CONSTANT_VECTOR, G_SPLAT_VECTOR, G_BROADCAST, or G_HOMOGENEOUS_VECTOR that takes an immediate and creates a (fixed or scalable) vector whose elements are all that immediate might solve the buildNot challenge and would facilitate new combines, pattern matching, and new selection optimizations.
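
For context, a hedged sketch of what the out-of-line buildNot could look like once it lives in MachineIRBuilder.cpp, assuming buildConstant keeps splatting -1 across fixed-length vectors; this is not necessarily the patch's exact implementation, and scalable vectors remain unsolved:

```cpp
// Sketch only: out-of-line buildNot for scalars and fixed-length vectors.
// Scalable vectors still need something like the splat opcode discussed above.
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"

using namespace llvm;

MachineInstrBuilder MachineIRBuilder::buildNot(const DstOp &Dst,
                                               const SrcOp &Src0) {
  LLT Ty = Dst.getLLTTy(*getMRI());
  assert(!Ty.isScalableVector() && "cannot build a scalable all-ones yet");
  // buildConstant emits G_CONSTANT -1 for scalars and splats -1 across a
  // G_BUILD_VECTOR for fixed-length vectors, giving an all-ones mask.
  auto NegOne = buildConstant(Ty, -1);
  return buildInstr(TargetOpcode::G_XOR, {Dst}, {Src0, NegOne});
}
```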

P.S. We need to support both integer and floating-point types.

https://github.com/llvm/llvm-project/pull/74502

```llvm
<vscale x 4 x i32> splat (i32 -1)
```

---

Patch is 696.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74845.diff


31 Files Affected:

- (modified) llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h (+18-10) 
- (modified) llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h (+5-5) 
- (modified) llvm/include/llvm/Target/GlobalISel/Combine.td (+7-27) 
- (modified) llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp (+340-72) 
- (modified) llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp (+9) 
- (modified) llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir (+261-4) 
- (modified) llvm/test/CodeGen/AArch64/GlobalISel/postlegalizercombiner-select.mir (+1-1) 
- (modified) llvm/test/CodeGen/AArch64/andcompare.ll (+9-5) 
- (modified) llvm/test/CodeGen/AArch64/arm64-ccmp.ll (+73-51) 
- (modified) llvm/test/CodeGen/AArch64/call-rv-marker.ll (+409-38) 
- (modified) llvm/test/CodeGen/AArch64/neon-bitwise-instructions.ll (+8-12) 
- (modified) llvm/test/CodeGen/AArch64/stack-probing-dynamic-no-frame-setup.ll (+4-2) 
- (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/combine-fold-binop-into-select.mir (+17-25) 
- (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/fshl.ll (+1088-1104) 
- (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/fshr.ll (+1105-1311) 
- (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wqm.demote.ll (+32-16) 
- (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/lshr.ll (+41-36) 
- (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/saddsat.ll (+173-176) 
- (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/shl.ll (+59-52) 
- (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/usubsat.ll (+587-307) 
- (modified) llvm/test/CodeGen/AMDGPU/ctlz_zero_undef.ll (+18-14) 
- (modified) llvm/test/CodeGen/AMDGPU/fdiv_flags.f32.ll (+8-6) 
- (modified) llvm/test/CodeGen/AMDGPU/fptrunc.ll (+24-22) 
- (modified) llvm/test/CodeGen/AMDGPU/fsqrt.f32.ll (+112-87) 
- (modified) llvm/test/CodeGen/AMDGPU/fsqrt.f64.ll (+238-260) 
- (modified) llvm/test/CodeGen/AMDGPU/llvm.amdgcn.inverse.ballot.i64.ll (+4-4) 
- (modified) llvm/test/CodeGen/AMDGPU/llvm.frexp.ll (+95-62) 
- (modified) llvm/test/CodeGen/AMDGPU/llvm.log.ll (+22-18) 
- (modified) llvm/test/CodeGen/AMDGPU/llvm.log10.ll (+22-18) 
- (modified) llvm/test/CodeGen/AMDGPU/llvm.log2.ll (+8-6) 
- (modified) llvm/test/CodeGen/AMDGPU/rsq.f64.ll (+450-500) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index a4e9c92b48976..f73e4ae7944df 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -423,16 +423,9 @@ class CombinerHelper {
   /// Return true if a G_STORE instruction \p MI is storing an undef value.
   bool matchUndefStore(MachineInstr &MI);
 
-  /// Return true if a G_SELECT instruction \p MI has an undef comparison.
-  bool matchUndefSelectCmp(MachineInstr &MI);
-
   /// Return true if a G_{EXTRACT,INSERT}_VECTOR_ELT has an out of range index.
   bool matchInsertExtractVecEltOutOfBounds(MachineInstr &MI);
 
-  /// Return true if a G_SELECT instruction \p MI has a constant comparison. If
-  /// true, \p OpIdx will store the operand index of the known selected value.
-  bool matchConstantSelectCmp(MachineInstr &MI, unsigned &OpIdx);
-
   /// Replace an instruction with a G_FCONSTANT with value \p C.
   void replaceInstWithFConstant(MachineInstr &MI, double C);
 
@@ -771,9 +764,6 @@ class CombinerHelper {
   bool matchCombineFSubFpExtFNegFMulToFMadOrFMA(MachineInstr &MI,
                                                 BuildFnTy &MatchInfo);
 
-  /// Fold boolean selects to logical operations.
-  bool matchSelectToLogical(MachineInstr &MI, BuildFnTy &MatchInfo);
-
   bool matchCombineFMinMaxNaN(MachineInstr &MI, unsigned &Info);
 
   /// Transform G_ADD(x, G_SUB(y, x)) to y.
@@ -816,6 +806,9 @@ class CombinerHelper {
   // Given a binop \p MI, commute operands 1 and 2.
   void applyCommuteBinOpOperands(MachineInstr &MI);
 
+  // Combine selects.
+  bool matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo);
+
 private:
   /// Checks for legality of an indexed variant of \p LdSt.
   bool isIndexedLoadStoreLegal(GLoadStore &LdSt) const;
@@ -906,6 +899,21 @@ class CombinerHelper {
   /// select (fcmp uge x, 1.0) 1.0, x -> fminnm x, 1.0
   bool matchFPSelectToMinMax(Register Dst, Register Cond, Register TrueVal,
                              Register FalseVal, BuildFnTy &MatchInfo);
+
+  bool isOneOrOneSplat(Register Src, bool AllowUndefs);
+  bool isZeroOrZeroSplat(Register Src, bool AllowUndefs);
+  bool isConstantSplatVector(Register Src, int64_t SplatValue,
+                             bool AllowUndefs);
+  std::optional<APInt> getConstantOrConstantSplatVector(Register Src);
+
+  /// Try to combine selects with constant conditions.
+  bool tryCombineSelectConstantCondition(GSelect *Select, BuildFnTy &MatchInfo);
+
+  /// Try to combine selects with boolean conditions to logical operators.
+  bool tryFoldBoolSelectToLogic(GSelect *Select, BuildFnTy &MatchInfo);
+
+  /// Try to combine selects where the true and false values are constant.
+  bool tryFoldSelectOfConstants(GSelect *Select, BuildFnTy &MatchInfo);
 };
 } // namespace llvm
 
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h b/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
index e0101a5ac1ca8..7dec611c3e27e 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
@@ -1701,11 +1701,11 @@ class MachineIRBuilder {
 
   /// Build and insert a bitwise not,
   /// \p NegOne = G_CONSTANT -1
-  /// \p Res = G_OR \p Op0, NegOne
-  MachineInstrBuilder buildNot(const DstOp &Dst, const SrcOp &Src0) {
-    auto NegOne = buildConstant(Dst.getLLTTy(*getMRI()), -1);
-    return buildInstr(TargetOpcode::G_XOR, {Dst}, {Src0, NegOne});
-  }
+  /// \p Res = G_XOR \p Op0, NegOne
+  /// Or
+  /// \p NegOne = G_BUILD_VECTOR -1, -1, -1, ...
+  /// \p Res = G_XOR \p Op0, NegOne
+  MachineInstrBuilder buildNot(const DstOp &Dst, const SrcOp &Src0);
 
   /// Build and insert integer negation
   /// \p Zero = G_CONSTANT 0
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 77db371adaf77..5444c368e598b 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -419,31 +419,6 @@ def select_same_val: GICombineRule<
   (apply [{ Helper.replaceSingleDefInstWithOperand(*${root}, 2); }])
 >;
 
-// Fold (undef ? x : y) -> y
-def select_undef_cmp: GICombineRule<
-  (defs root:$dst),
-  (match (G_IMPLICIT_DEF $undef),
-         (G_SELECT $dst, $undef, $x, $y)),
-  (apply (GIReplaceReg $dst, $y))
->;
-
-// Fold (true ? x : y) -> x
-// Fold (false ? x : y) -> y
-def select_constant_cmp_matchdata : GIDefMatchData<"unsigned">;
-def select_constant_cmp: GICombineRule<
-  (defs root:$root, select_constant_cmp_matchdata:$matchinfo),
-  (match (wip_match_opcode G_SELECT):$root,
-    [{ return Helper.matchConstantSelectCmp(*${root}, ${matchinfo}); }]),
-  (apply [{ Helper.replaceSingleDefInstWithOperand(*${root}, ${matchinfo}); }])
->;
-
-def select_to_logical : GICombineRule<
-  (defs root:$root, build_fn_matchinfo:$matchinfo),
-  (match (wip_match_opcode G_SELECT):$root,
-    [{ return Helper.matchSelectToLogical(*${root}, ${matchinfo}); }]),
-  (apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])
->;
-
 // Fold (C op x) -> (x op C)
 // TODO: handle more isCommutable opcodes
 // TODO: handle compares (currently not marked as isCommutable)
@@ -1242,6 +1217,12 @@ def select_to_minmax: GICombineRule<
          [{ return Helper.matchSimplifySelectToMinMax(*${root}, ${info}); }]),
   (apply [{ Helper.applyBuildFn(*${root}, ${info}); }])>;
 
+def match_selects : GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$matchinfo),
+  (match (wip_match_opcode G_SELECT):$root,
+        [{ return Helper.matchSelect(*${root}, ${matchinfo}); }]),
+  (apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
+
 // FIXME: These should use the custom predicate feature once it lands.
 def undef_combines : GICombineGroup<[undef_to_fp_zero, undef_to_int_zero,
                                      undef_to_negative_one,
@@ -1281,8 +1262,7 @@ def width_reduction_combines : GICombineGroup<[reduce_shl_of_extend,
 
 def phi_combines : GICombineGroup<[extend_through_phis]>;
 
-def select_combines : GICombineGroup<[select_undef_cmp, select_constant_cmp,
-                                      select_to_logical]>;
+def select_combines : GICombineGroup<[match_selects]>;
 
 def trivial_combines : GICombineGroup<[copy_prop, mul_to_shl, add_p2i_to_ptradd,
                                        mul_by_neg_one, idempotent_prop]>;
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 91a64d59e154d..b12f83e75859a 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -2611,12 +2611,6 @@ bool CombinerHelper::matchUndefStore(MachineInstr &MI) {
                       MRI);
 }
 
-bool CombinerHelper::matchUndefSelectCmp(MachineInstr &MI) {
-  assert(MI.getOpcode() == TargetOpcode::G_SELECT);
-  return getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MI.getOperand(1).getReg(),
-                      MRI);
-}
-
 bool CombinerHelper::matchInsertExtractVecEltOutOfBounds(MachineInstr &MI) {
   assert((MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT ||
           MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) &&
@@ -2630,16 +2624,6 @@ bool CombinerHelper::matchInsertExtractVecEltOutOfBounds(MachineInstr &MI) {
   return Idx->getZExtValue() >= VecTy.getNumElements();
 }
 
-bool CombinerHelper::matchConstantSelectCmp(MachineInstr &MI, unsigned &OpIdx) {
-  GSelect &SelMI = cast<GSelect>(MI);
-  auto Cst =
-      isConstantOrConstantSplatVector(*MRI.getVRegDef(SelMI.getCondReg()), MRI);
-  if (!Cst)
-    return false;
-  OpIdx = Cst->isZero() ? 3 : 2;
-  return true;
-}
-
 void CombinerHelper::eraseInst(MachineInstr &MI) { MI.eraseFromParent(); }
 
 bool CombinerHelper::matchEqualDefs(const MachineOperand &MOP1,
@@ -5940,62 +5924,6 @@ bool CombinerHelper::matchCombineFSubFpExtFNegFMulToFMadOrFMA(
   return false;
 }
 
-bool CombinerHelper::matchSelectToLogical(MachineInstr &MI,
-                                          BuildFnTy &MatchInfo) {
-  GSelect &Sel = cast<GSelect>(MI);
-  Register DstReg = Sel.getReg(0);
-  Register Cond = Sel.getCondReg();
-  Register TrueReg = Sel.getTrueReg();
-  Register FalseReg = Sel.getFalseReg();
-
-  auto *TrueDef = getDefIgnoringCopies(TrueReg, MRI);
-  auto *FalseDef = getDefIgnoringCopies(FalseReg, MRI);
-
-  const LLT CondTy = MRI.getType(Cond);
-  const LLT OpTy = MRI.getType(TrueReg);
-  if (CondTy != OpTy || OpTy.getScalarSizeInBits() != 1)
-    return false;
-
-  // We have a boolean select.
-
-  // select Cond, Cond, F --> or Cond, F
-  // select Cond, 1, F    --> or Cond, F
-  auto MaybeCstTrue = isConstantOrConstantSplatVector(*TrueDef, MRI);
-  if (Cond == TrueReg || (MaybeCstTrue && MaybeCstTrue->isOne())) {
-    MatchInfo = [=](MachineIRBuilder &MIB) {
-      MIB.buildOr(DstReg, Cond, FalseReg);
-    };
-    return true;
-  }
-
-  // select Cond, T, Cond --> and Cond, T
-  // select Cond, T, 0    --> and Cond, T
-  auto MaybeCstFalse = isConstantOrConstantSplatVector(*FalseDef, MRI);
-  if (Cond == FalseReg || (MaybeCstFalse && MaybeCstFalse->isZero())) {
-    MatchInfo = [=](MachineIRBuilder &MIB) {
-      MIB.buildAnd(DstReg, Cond, TrueReg);
-    };
-    return true;
-  }
-
- // select Cond, T, 1 --> or (not Cond), T
-  if (MaybeCstFalse && MaybeCstFalse->isOne()) {
-    MatchInfo = [=](MachineIRBuilder &MIB) {
-      MIB.buildOr(DstReg, MIB.buildNot(OpTy, Cond), TrueReg);
-    };
-    return true;
-  }
-
-  // select Cond, 0, F --> and (not Cond), F
-  if (MaybeCstTrue && MaybeCstTrue->isZero()) {
-    MatchInfo = [=](MachineIRBuilder &MIB) {
-      MIB.buildAnd(DstReg, MIB.buildNot(OpTy, Cond), FalseReg);
-    };
-    return true;
-  }
-  return false;
-}
-
 bool CombinerHelper::matchCombineFMinMaxNaN(MachineInstr &MI,
                                             unsigned &IdxToPropagate) {
   bool PropagateNaN;
@@ -6318,3 +6246,343 @@ void CombinerHelper::applyCommuteBinOpOperands(MachineInstr &MI) {
   MI.getOperand(2).setReg(LHSReg);
   Observer.changedInstr(MI);
 }
+
+bool CombinerHelper::isOneOrOneSplat(Register Src, bool AllowUndefs) {
+  LLT SrcTy = MRI.getType(Src);
+  if (SrcTy.isFixedVector())
+    return isConstantSplatVector(Src, 1, AllowUndefs);
+  if (SrcTy.isScalar()) {
+    if (AllowUndefs && getOpcodeDef<GImplicitDef>(Src, MRI) != nullptr)
+      return true;
+    auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI);
+    return IConstant && IConstant->Value == 1;
+  }
+  return false; // scalable vector
+}
+
+bool CombinerHelper::isZeroOrZeroSplat(Register Src, bool AllowUndefs) {
+  LLT SrcTy = MRI.getType(Src);
+  if (SrcTy.isFixedVector())
+    return isConstantSplatVector(Src, 0, AllowUndefs);
+  if (SrcTy.isScalar()) {
+    if (AllowUndefs && getOpcodeDef<GImplicitDef>(Src, MRI) != nullptr)
+      return true;
+    auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI);
+    return IConstant && IConstant->Value == 0;
+  }
+  return false; // scalable vector
+}
+
+// Ignores COPYs during conformance checks.
+// FIXME scalable vectors.
+bool CombinerHelper::isConstantSplatVector(Register Src, int64_t SplatValue,
+                                           bool AllowUndefs) {
+  GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI);
+  if (!BuildVector)
+    return false;
+  unsigned NumSources = BuildVector->getNumSources();
+
+  for (unsigned I = 0; I < NumSources; ++I) {
+    GImplicitDef *ImplicitDef =
+        getOpcodeDef<GImplicitDef>(BuildVector->getSourceReg(I), MRI);
+    if (ImplicitDef && AllowUndefs)
+      continue;
+    if (ImplicitDef && !AllowUndefs)
+      return false;
+    std::optional<ValueAndVReg> IConstant =
+        getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI);
+    if (IConstant && IConstant->Value == SplatValue)
+      continue;
+    return false;
+  }
+  return true;
+}
+
+// Ignores COPYs during lookups.
+// FIXME scalable vectors
+std::optional<APInt>
+CombinerHelper::getConstantOrConstantSplatVector(Register Src) {
+  auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI);
+  if (IConstant)
+    return IConstant->Value;
+
+  GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI);
+  if (!BuildVector)
+    return std::nullopt;
+  unsigned NumSources = BuildVector->getNumSources();
+
+  std::optional<APInt> Value = std::nullopt;
+  for (unsigned I = 0; I < NumSources; ++I) {
+    std::optional<ValueAndVReg> IConstant =
+        getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI);
+    if (!IConstant)
+      return std::nullopt;
+    if (!Value)
+      Value = IConstant->Value;
+    else if (*Value != IConstant->Value)
+      return std::nullopt;
+  }
+  return Value;
+}
+
+// TODO: use knownbits to determine zeros
+bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
+                                              BuildFnTy &MatchInfo) {
+  uint32_t Flags = Select->getFlags();
+  Register Dest = Select->getReg(0);
+  Register Cond = Select->getCondReg();
+  Register True = Select->getTrueReg();
+  Register False = Select->getFalseReg();
+  LLT CondTy = MRI.getType(Select->getCondReg());
+  LLT TrueTy = MRI.getType(Select->getTrueReg());
+
+  // Either both are scalars or both are vectors.
+  std::optional<APInt> TrueOpt = getConstantOrConstantSplatVector(True);
+  std::optional<APInt> FalseOpt = getConstantOrConstantSplatVector(False);
+
+  if (!TrueOpt || !FalseOpt)
+    return false;
+
+  // These are only the splat values.
+  APInt TrueValue = *TrueOpt;
+  APInt FalseValue = *FalseOpt;
+
+  // Boolean or fixed vector of booleans.
+  if (CondTy.isScalableVector() ||
+      (CondTy.isFixedVector() &&
+       CondTy.getElementType().getScalarSizeInBits() != 1) ||
+      CondTy.getScalarSizeInBits() != 1)
+    return false;
+
+  // select Cond, 1, 0 --> zext (Cond)
+  if (TrueValue.isOne() && FalseValue.isZero()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      B.buildZExtOrTrunc(Dest, Cond);
+    };
+    return true;
+  }
+
+  // select Cond, -1, 0 --> sext (Cond)
+  if (TrueValue.isAllOnes() && FalseValue.isZero()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      B.buildSExtOrTrunc(Dest, Cond);
+    };
+    return true;
+  }
+
+  // select Cond, 0, 1 --> zext (!Cond)
+  if (TrueValue.isZero() && FalseValue.isOne()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(CondTy);
+      B.buildNot(Inner, Cond);
+      B.buildZExtOrTrunc(Dest, Inner);
+    };
+    return true;
+  }
+
+  // select Cond, 0, -1 --> sext (!Cond)
+  if (TrueValue.isZero() && FalseValue.isAllOnes()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(CondTy);
+      B.buildNot(Inner, Cond);
+      B.buildSExtOrTrunc(Dest, Inner);
+    };
+    return true;
+  }
+
+  // select Cond, C1, C1-1 --> add (zext Cond), C1-1
+  if (TrueValue - 1 == FalseValue) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Inner, Cond);
+      B.buildAdd(Dest, Inner, False);
+    };
+    return true;
+  }
+
+  // select Cond, C1, C1+1 --> add (sext Cond), C1+1
+  if (TrueValue + 1 == FalseValue) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildSExtOrTrunc(Inner, Cond);
+      B.buildAdd(Dest, Inner, False);
+    };
+    return true;
+  }
+
+  // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
+  if (TrueValue.isPowerOf2() && FalseValue.isZero()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Inner, Cond);
+      // The shift amount must be scalar.
+      LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy;
+      auto ShAmtC = B.buildConstant(ShiftTy, TrueValue.exactLogBase2());
+      B.buildShl(Dest, Inner, ShAmtC, Flags);
+    };
+    return true;
+  }
+
+  // select Cond, -1, C --> or (sext Cond), C
+  if (TrueValue.isAllOnes()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildSExtOrTrunc(Inner, Cond);
+      B.buildOr(Dest, Inner, False, Flags);
+    };
+    return true;
+  }
+
+  // select Cond, C, -1 --> or (sext (not Cond)), C
+  if (FalseValue.isAllOnes()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Not = MRI.createGenericVirtualRegister(CondTy);
+      B.buildNot(Not, Cond);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildSExtOrTrunc(Inner, Not);
+      B.buildOr(Dest, Inner, True, Flags);
+    };
+    return true;
+  }
+
+  return false;
+}
+
+// TODO: use knownbits to determine zeros
+bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
+                                              BuildFnTy &MatchInfo) {
+  uint32_t Flags = Select->getFlags();
+  Register DstReg = Select->getReg(0);
+  Register Cond = Select->getCondReg();
+  Register True = Select->getTrueReg();
+  Register False = Select->getFalseReg();
+  LLT CondTy = MRI.getType(Select->getCondReg());
+  LLT TrueTy = MRI.getType(Select->getTrueReg());
+
+  // Boolean or fixed vector of booleans.
+  if (CondTy.isScalableVector() ||
+      (CondTy.isFixedVector() &&
+       CondTy.getElementType().getScalarSizeInBits() != 1) ||
+      CondTy.getScalarSizeInBits() != 1)
+    return false;
+
+  // select Cond, Cond, F --> or Cond, F
+  // select Cond, 1, F    --> or Cond, F
+  if ((Cond == True) || isOneOrOneSplat(True, /* AllowUndefs */ true)) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Ext, Cond);
+      B.buildOr(DstReg, Ext, False, Flags);
+    };
+    return true;
+  }
+
+  // select Cond, T, Cond --> and Cond, T
+  // select Cond, T, 0    --> and Cond, T
+  if ((Cond == False) || isZeroOrZeroSplat(False, /* AllowUndefs */ true)) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Ext, Cond);
+      B.buildAnd(DstReg, Ext, True);
+    };
+    return true;
+  }
+
+  // select Cond, T, 1 --> or (not Cond), T
+  if (isOneOrOneSplat(False, /* AllowUndefs */ true)) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      // First the not.
+      Register Inner = MRI.createGenericVirtualRegister(CondTy);
+      B.buildNot(Inner, Cond);
+      // Then an ext to match the destination register.
+      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Ext, Cond);
+      B.buildOr(DstReg, Ext, True, Flags);
+    };
+    return true;
+  }
+
+  // select Cond, 0, F --> and (not Cond), F
+  if (isZeroOrZeroSplat(True, /* AllowUndefs */ true)) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      // First the not.
+      Register Inner = MRI.createGenericVirtualRegister(CondTy);
+      B.buildNot(Inner, Cond);
+      // Then an ext to match the destination register.
+      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Ext, Inner);
+      B.buildAnd(DstReg, Ext, False);
+    };
+    return true;
+  }
+
+  return false;
+}
+
+bool CombinerHelper::tryCombineSelectConstantCondition(GSelect *Select,
+                                                       BuildFnTy &MatchInfo) {
+  Register Dest = Select->getReg(0);
+  Register Cond = Select->getCondReg();
+  Register True = Select->getTrueReg();
+  Register False = Select->getFalseReg();
+  LLT CondTy = MRI.getType(Select->getCondReg());
+
+  KnownBits Known = KB->getKnownBits(Cond);
+  if (Known.isZero()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      B.buildCopy(Dest, False);
+    };
+    return true;
+  } else if (CondTy.isScalar...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/74845


More information about the llvm-commits mailing list