[llvm] [GlobalIsel] Visit ICmp (PR #105991)

via llvm-commits llvm-commits at lists.llvm.org
Sun Aug 25 08:58:20 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-globalisel

Author: Thorsten Schütt (tschuett)

<details>
<summary>Changes</summary>

Inspired by `simplifyICmpInst` and `simplifyICmpWithZero` in LLVM's InstructionSimplify.

---

Patch is 97.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/105991.diff


15 Files Affected:

- (modified) llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h (+10) 
- (modified) llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h (+24) 
- (modified) llvm/include/llvm/CodeGen/GlobalISel/Utils.h (+26) 
- (modified) llvm/include/llvm/Target/GlobalISel/Combine.td (+44-7) 
- (modified) llvm/lib/CodeGen/GlobalISel/CMakeLists.txt (+1) 
- (added) llvm/lib/CodeGen/GlobalISel/CombinerHelperCompares.cpp (+305) 
- (modified) llvm/lib/CodeGen/GlobalISel/Utils.cpp (+323) 
- (modified) llvm/test/CodeGen/AArch64/GlobalISel/arm64-atomic.ll (+56-40) 
- (modified) llvm/test/CodeGen/AArch64/GlobalISel/arm64-pcsections.ll (+32-24) 
- (added) llvm/test/CodeGen/AArch64/GlobalISel/combine-visit-icmp.mir (+167) 
- (modified) llvm/test/CodeGen/AArch64/arm64-ccmp.ll (+18-42) 
- (added) llvm/test/CodeGen/AArch64/icmp2.ll (+295) 
- (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/fdiv.f64.ll (+9-9) 
- (modified) llvm/test/CodeGen/AMDGPU/itofp.i128.ll (+68-86) 
- (modified) llvm/test/CodeGen/AMDGPU/rsq.f64.ll (+23-23) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index 9b62d6067be39c..da9c7fdbd2a093 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -20,6 +20,7 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
+#include "llvm/CodeGen/GlobalISel/Utils.h"
 #include "llvm/CodeGen/Register.h"
 #include "llvm/CodeGenTypes/LowLevelType.h"
 #include "llvm/IR/InstrTypes.h"
@@ -299,6 +300,12 @@ class CombinerHelper {
   ///     $whatever = COPY $addr
   bool tryCombineMemCpyFamily(MachineInstr &MI, unsigned MaxLen = 0);
 
+  bool visitICmp(const MachineInstr &MI, BuildFnTy &MatchInfo);
+  bool matchSextOfICmp(const MachineInstr &MI, BuildFnTy &MatchInfo);
+  bool matchZextOfICmp(const MachineInstr &MI, BuildFnTy &MatchInfo);
+  /// Try hard to fold icmp with zero RHS because this is a common case.
+  bool matchCmpOfZero(const MachineInstr &MI, BuildFnTy &MatchInfo);
+
   bool matchPtrAddImmedChain(MachineInstr &MI, PtrAddChain &MatchInfo);
   void applyPtrAddImmedChain(MachineInstr &MI, PtrAddChain &MatchInfo);
 
@@ -1017,6 +1024,9 @@ class CombinerHelper {
   bool tryFoldLogicOfFCmps(GLogicalBinOp *Logic, BuildFnTy &MatchInfo);
 
   bool isCastFree(unsigned Opcode, LLT ToTy, LLT FromTy) const;
+
+  bool constantFoldICmp(const GICmp &ICmp, const GIConstant &LHS,
+                        const GIConstant &RHS, BuildFnTy &MatchInfo);
 };
 } // namespace llvm
 
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
index ef1171d9f1f64d..427b5a86b6e0c4 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
@@ -950,6 +950,30 @@ class GExtOrTruncOp : public GCastOp {
   };
 };
 
+/// Represents a G_SPLAT_VECTOR; getValueReg() returns the splatted scalar.
+class GSplatVector : public GenericMachineInstr {
+public:
+  Register getValueReg() const { return getOperand(1).getReg(); }
+
+  static bool classof(const MachineInstr *MI) {
+    return MI->getOpcode() == TargetOpcode::G_SPLAT_VECTOR;
+  };
+};
+
+/// Represents an integer extension: either G_SEXT or G_ZEXT.
+class GZextOrSextOp : public GCastOp {
+public:
+  static bool classof(const MachineInstr *MI) {
+    switch (MI->getOpcode()) {
+    case TargetOpcode::G_SEXT:
+    case TargetOpcode::G_ZEXT:
+      return true;
+    default:
+      return false;
+    }
+  };
+};
+
 } // namespace llvm
 
 #endif // LLVM_CODEGEN_GLOBALISEL_GENERICMACHINEINSTRS_H
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
index cf5fd6d6f288bd..a8bf2e722881ac 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
@@ -593,5 +593,31 @@ bool isGuaranteedNotToBeUndef(Register Reg, const MachineRegisterInfo &MRI,
 /// estimate of the type.
 Type *getTypeForLLT(LLT Ty, LLVMContext &C);
 
+enum class GIConstantKind { Scalar, FixedVector, ScalableVector };
+
+/// An integer-like constant: a scalar, a fixed vector, or a splat.
+class GIConstant {
+  GIConstantKind Kind;
+  SmallVector<APInt> Values; // Elements of a FixedVector constant.
+  APInt Value;               // Scalar value, or the splatted element.
+
+public:
+  explicit GIConstant(ArrayRef<APInt> Values)
+      : Kind(GIConstantKind::FixedVector), Values(Values) {}
+  GIConstant(const APInt &Value, GIConstantKind Kind)
+      : Kind(Kind), Value(Value) {}
+
+  GIConstantKind getKind() const { return Kind; }
+  /// Value of a Scalar constant; asserts for any other kind.
+  APInt getScalarValue() const;
+  /// Integer constant (scalar, build vector or splat) defining \p Const.
+  static std::optional<GIConstant> getConstant(Register Const,
+                                               const MachineRegisterInfo &MRI);
+};
+
+/// Return true if the given value is known to be non-zero when defined.
+bool isKnownNonZero(Register Reg, const MachineRegisterInfo &MRI,
+                    GISelKnownBits *KB, unsigned Depth = 0);
+
 } // End namespace llvm.
 #endif
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 525cc815e73cef..175a8ed57b2669 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -1007,9 +1007,6 @@ def double_icmp_zero_or_combine: GICombineRule<
          (G_ICMP $root, $p, $ordst, 0))
 >;
 
-def double_icmp_zero_and_or_combine : GICombineGroup<[double_icmp_zero_and_combine,
-                                                      double_icmp_zero_or_combine]>;
-
 def and_or_disjoint_mask : GICombineRule<
   (defs root:$root, build_fn_matchinfo:$info),
   (match (wip_match_opcode G_AND):$root,
@@ -1884,6 +1881,46 @@ def cast_combines: GICombineGroup<[
   buildvector_of_truncate
 ]>;
 
+def visit_icmp : GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$matchinfo),
+  (match (G_ICMP $root, $pred, $lhs, $rhs):$cmp,
+         [{ return Helper.visitICmp(*${cmp}, ${matchinfo}); }]),
+  (apply [{ Helper.applyBuildFn(*${cmp}, ${matchinfo}); }])>;
+
+def sext_icmp : GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$matchinfo),
+  (match (G_SEXT $rhs, $inputR),
+         (G_SEXT $lhs, $inputL),
+         (G_ICMP $root, $pred, $lhs, $rhs):$cmp,
+         [{ return Helper.matchSextOfICmp(*${cmp}, ${matchinfo}); }]),
+  (apply [{ Helper.applyBuildFn(*${cmp}, ${matchinfo}); }])>;
+
+def zext_icmp : GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$matchinfo),
+  (match (G_ZEXT $rhs, $inputR),
+         (G_ZEXT $lhs, $inputL),
+         (G_ICMP $root, $pred, $lhs, $rhs):$cmp,
+         [{ return Helper.matchZextOfICmp(*${cmp}, ${matchinfo}); }]),
+  (apply [{ Helper.applyBuildFn(*${cmp}, ${matchinfo}); }])>;
+
+def icmp_of_zero : GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$matchinfo),
+  (match (G_CONSTANT $zero, 0),
+         (G_ICMP $root, $pred, $lhs, $zero):$cmp,
+         [{ return Helper.matchCmpOfZero(*${cmp}, ${matchinfo}); }]),
+  (apply [{ Helper.applyBuildFn(*${cmp}, ${matchinfo}); }])>;
+
+def icmp_combines: GICombineGroup<[
+  visit_icmp,
+  sext_icmp,
+  zext_icmp,
+  icmp_of_zero,
+  icmp_to_true_false_known_bits,
+  icmp_to_lhs_known_bits,
+  double_icmp_zero_and_combine,
+  double_icmp_zero_or_combine,
+  redundant_binop_in_equality
+]>;
 
 // FIXME: These should use the custom predicate feature once it lands.
 def undef_combines : GICombineGroup<[undef_to_fp_zero, undef_to_int_zero,
@@ -1917,7 +1954,7 @@ def const_combines : GICombineGroup<[constant_fold_fp_ops, const_ptradd_to_i2p,
 
 def known_bits_simplifications : GICombineGroup<[
   redundant_and, redundant_sext_inreg, redundant_or, urem_pow2_to_mask,
-  zext_trunc_fold, icmp_to_true_false_known_bits, icmp_to_lhs_known_bits,
+  zext_trunc_fold,
   sext_inreg_to_zext_inreg]>;
 
 def width_reduction_combines : GICombineGroup<[reduce_shl_of_extend,
@@ -1944,7 +1981,7 @@ def constant_fold_binops : GICombineGroup<[constant_fold_binop,
 
 def prefer_sign_combines : GICombineGroup<[nneg_zext]>;
 
-def all_combines : GICombineGroup<[integer_reassoc_combines, trivial_combines,
+def all_combines : GICombineGroup<[icmp_combines, integer_reassoc_combines, trivial_combines,
     vector_ops_combines, freeze_combines, cast_combines,
     insert_vec_elt_combines, extract_vec_elt_combines, combines_for_extload,
     combine_extracted_vector_load,
@@ -1964,9 +2001,9 @@ def all_combines : GICombineGroup<[integer_reassoc_combines, trivial_combines,
     constant_fold_cast_op, fabs_fneg_fold,
     intdiv_combines, mulh_combines, redundant_neg_operands,
     and_or_disjoint_mask, fma_combines, fold_binop_into_select,
-    sub_add_reg, select_to_minmax, redundant_binop_in_equality,
+    sub_add_reg, select_to_minmax,
     fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors,
-    combine_concat_vector, double_icmp_zero_and_or_combine, match_addos,
+    combine_concat_vector, match_addos,
     sext_trunc, zext_trunc, prefer_sign_combines, combine_shuffle_concat]>;
 
 // A combine group used to for prelegalizer combiners at -O0. The combines in
diff --git a/llvm/lib/CodeGen/GlobalISel/CMakeLists.txt b/llvm/lib/CodeGen/GlobalISel/CMakeLists.txt
index a15b76440364b1..af1717dbf76f39 100644
--- a/llvm/lib/CodeGen/GlobalISel/CMakeLists.txt
+++ b/llvm/lib/CodeGen/GlobalISel/CMakeLists.txt
@@ -7,6 +7,7 @@ add_llvm_component_library(LLVMGlobalISel
   Combiner.cpp
   CombinerHelper.cpp
   CombinerHelperCasts.cpp
+  CombinerHelperCompares.cpp
   CombinerHelperVectorOps.cpp
   GIMatchTableExecutor.cpp
   GISelChangeObserver.cpp
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelperCompares.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelperCompares.cpp
new file mode 100644
index 00000000000000..415768fb07e59f
--- /dev/null
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelperCompares.cpp
@@ -0,0 +1,268 @@
+//===- CombinerHelperCompares.cpp------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the G_ICMP combines of CombinerHelper.
+//
+//===----------------------------------------------------------------------===//
+#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
+#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
+#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
+#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
+#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
+#include "llvm/CodeGen/GlobalISel/Utils.h"
+#include "llvm/CodeGen/LowLevelTypeUtils.h"
+#include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineOperand.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/TargetOpcodes.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/ErrorHandling.h"
+#include <cstdlib>
+
+#define DEBUG_TYPE "gi-combiner"
+
+using namespace llvm;
+
+/// Constant fold a G_ICMP whose operands are both integer constants. Only
+/// scalar constants are handled; vector constants are left alone.
+bool CombinerHelper::constantFoldICmp(const GICmp &ICmp,
+                                      const GIConstant &LHSCst,
+                                      const GIConstant &RHSCst,
+                                      BuildFnTy &MatchInfo) {
+  // getScalarValue() asserts on non-scalar kinds, so check both operands.
+  if (LHSCst.getKind() != GIConstantKind::Scalar ||
+      RHSCst.getKind() != GIConstantKind::Scalar)
+    return false;
+
+  Register Dst = ICmp.getReg(0);
+  LLT DstTy = MRI.getType(Dst);
+
+  if (!isConstantLegalOrBeforeLegalizer(DstTy))
+    return false;
+
+  bool Result = ICmpInst::compare(LHSCst.getScalarValue(),
+                                  RHSCst.getScalarValue(), ICmp.getCond());
+
+  MatchInfo = [=](MachineIRBuilder &B) {
+    if (Result)
+      B.buildConstant(Dst, getICmpTrueVal(getTargetLowering(),
+                                          /*IsVector=*/DstTy.isVector(),
+                                          /*IsFP=*/false));
+    else
+      B.buildConstant(Dst, 0);
+  };
+
+  return true;
+}
+
+/// Simplify a G_ICMP: constant fold it, move a constant LHS to the RHS, and
+/// fold compares of a value against itself or against undef.
+bool CombinerHelper::visitICmp(const MachineInstr &MI, BuildFnTy &MatchInfo) {
+  const GICmp *Cmp = cast<GICmp>(&MI);
+
+  Register Dst = Cmp->getReg(0);
+  LLT DstTy = MRI.getType(Dst);
+  Register LHS = Cmp->getLHSReg();
+  Register RHS = Cmp->getRHSReg();
+
+  CmpInst::Predicate Pred = Cmp->getCond();
+  assert(CmpInst::isIntPredicate(Pred) && "Not an integer compare!");
+  if (auto CLHS = GIConstant::getConstant(LHS, MRI)) {
+    if (auto CRHS = GIConstant::getConstant(RHS, MRI))
+      return constantFoldICmp(*Cmp, *CLHS, *CRHS, MatchInfo);
+
+    // If we have a constant, make sure it is on the RHS.
+    std::swap(LHS, RHS);
+    Pred = CmpInst::getSwappedPredicate(Pred);
+
+    MatchInfo = [=](MachineIRBuilder &B) { B.buildICmp(Pred, Dst, LHS, RHS); };
+    return true;
+  }
+
+  MachineInstr *MIRHS = MRI.getVRegDef(RHS);
+
+  // For EQ and NE, we can always pick a value for the undef to make the
+  // predicate pass or fail, so we can return undef.
+  // Matches behavior in llvm::ConstantFoldCompareInstruction.
+  if (isa<GImplicitDef>(MIRHS) && ICmpInst::isEquality(Pred) &&
+      isLegalOrBeforeLegalizer({TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) {
+    MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Dst); };
+    return true;
+  }
+
+  // icmp X, X -> true/false
+  // icmp X, undef -> true/false because undef could be X.
+  if ((LHS == RHS || isa<GImplicitDef>(MIRHS)) &&
+      isConstantLegalOrBeforeLegalizer(DstTy)) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      if (CmpInst::isTrueWhenEqual(Pred))
+        B.buildConstant(Dst, getICmpTrueVal(getTargetLowering(),
+                                            /*IsVector=*/DstTy.isVector(),
+                                            /*IsFP=*/false));
+      else
+        B.buildConstant(Dst, 0);
+    };
+    return true;
+  }
+
+  return false;
+}
+
+/// Fold icmp pred (sext X), (sext Y) -> icmp pred X, Y when X and Y have the
+/// same type. Sign extension preserves both signed and unsigned ordering, so
+/// the predicate is unchanged.
+bool CombinerHelper::matchSextOfICmp(const MachineInstr &MI,
+                                     BuildFnTy &MatchInfo) {
+  const GICmp *Cmp = cast<GICmp>(&MI);
+
+  Register Dst = Cmp->getReg(0);
+  LLT DstTy = MRI.getType(Dst);
+  Register LHS = Cmp->getLHSReg();
+  Register RHS = Cmp->getRHSReg();
+  CmpInst::Predicate Pred = Cmp->getCond();
+
+  // The sext_icmp pattern guarantees both operands are defined by G_SEXT.
+  Register SrcL = cast<GSext>(MRI.getVRegDef(LHS))->getSrcReg();
+  Register SrcR = cast<GSext>(MRI.getVRegDef(RHS))->getSrcReg();
+
+  // Turn icmp (sext X), (sext Y) into a compare of X and Y if they have the
+  // same type.
+  LLT SrcTy = MRI.getType(SrcL);
+  if (SrcTy != MRI.getType(SrcR))
+    return false;
+
+  if (!isLegalOrBeforeLegalizer({TargetOpcode::G_ICMP, {DstTy, SrcTy}}))
+    return false;
+
+  // Compare X and Y. Note that the predicate does not change.
+  MatchInfo = [=](MachineIRBuilder &B) { B.buildICmp(Pred, Dst, SrcL, SrcR); };
+  return true;
+}
+
+/// Fold icmp pred (zext X), (zext Y) -> icmp upred X, Y when X and Y have the
+/// same scalar type. Both operands are zero-extended, so a signed predicate
+/// behaves like its unsigned counterpart.
+bool CombinerHelper::matchZextOfICmp(const MachineInstr &MI,
+                                     BuildFnTy &MatchInfo) {
+  const GICmp *Cmp = cast<GICmp>(&MI);
+
+  Register Dst = Cmp->getReg(0);
+  LLT DstTy = MRI.getType(Dst);
+  Register LHS = Cmp->getLHSReg();
+  Register RHS = Cmp->getRHSReg();
+  CmpInst::Predicate Pred = Cmp->getCond();
+
+  // Only plain scalars are handled for now. LLT::isScalar() is false for both
+  // pointers and vectors, so no separate pointer check is needed.
+  if (!MRI.getType(LHS).isScalar() || !MRI.getType(RHS).isScalar())
+    return false;
+
+  // The zext_icmp pattern guarantees both operands are defined by G_ZEXT.
+  Register SrcL = cast<GZext>(MRI.getVRegDef(LHS))->getSrcReg();
+  Register SrcR = cast<GZext>(MRI.getVRegDef(RHS))->getSrcReg();
+
+  // Turn icmp (zext X), (zext Y) into a compare of X and Y if they have
+  // the same type.
+  LLT SrcTy = MRI.getType(SrcL);
+  if (SrcTy != MRI.getType(SrcR))
+    return false;
+
+  if (!isLegalOrBeforeLegalizer({TargetOpcode::G_ICMP, {DstTy, SrcTy}}))
+    return false;
+
+  // Compare X and Y. Note that signed predicates become unsigned.
+  MatchInfo = [=](MachineIRBuilder &B) {
+    B.buildICmp(ICmpInst::getUnsignedPredicate(Pred), Dst, SrcL, SrcR);
+  };
+  return true;
+}
+
+/// Try hard to fold icmp X, 0 to a constant: either the predicate alone
+/// decides the result, or known bits / known-non-zero facts about X do.
+bool CombinerHelper::matchCmpOfZero(const MachineInstr &MI,
+                                    BuildFnTy &MatchInfo) {
+  const GICmp *Cmp = cast<GICmp>(&MI);
+
+  Register Dst = Cmp->getReg(0);
+  LLT DstTy = MRI.getType(Dst);
+  Register LHS = Cmp->getLHSReg();
+  CmpInst::Predicate Pred = Cmp->getCond();
+
+  if (!isConstantLegalOrBeforeLegalizer(DstTy))
+    return false;
+
+  std::optional<bool> Result;
+
+  switch (Pred) {
+  default:
+    llvm_unreachable("Unknown ICmp predicate!");
+  case ICmpInst::ICMP_ULT: // No value is unsigned-less-than zero.
+    Result = false;
+    break;
+  case ICmpInst::ICMP_UGE: // Every value is unsigned-greater-or-equal zero.
+    Result = true;
+    break;
+  case ICmpInst::ICMP_EQ:
+  case ICmpInst::ICMP_ULE:
+    if (isKnownNonZero(LHS, MRI, KB))
+      Result = false;
+    break;
+  case ICmpInst::ICMP_NE:
+  case ICmpInst::ICMP_UGT:
+    if (isKnownNonZero(LHS, MRI, KB))
+      Result = true;
+    break;
+  case ICmpInst::ICMP_SLT: {
+    KnownBits LHSKnown = KB->getKnownBits(LHS);
+    if (LHSKnown.isNegative())
+      Result = true;
+    if (LHSKnown.isNonNegative())
+      Result = false;
+    break;
+  }
+  case ICmpInst::ICMP_SLE: {
+    KnownBits LHSKnown = KB->getKnownBits(LHS);
+    if (LHSKnown.isNegative())
+      Result = true;
+    if (LHSKnown.isNonNegative() && isKnownNonZero(LHS, MRI, KB))
+      Result = false;
+    break;
+  }
+  case ICmpInst::ICMP_SGE: {
+    KnownBits LHSKnown = KB->getKnownBits(LHS);
+    if (LHSKnown.isNegative())
+      Result = false;
+    if (LHSKnown.isNonNegative())
+      Result = true;
+    break;
+  }
+  case ICmpInst::ICMP_SGT: {
+    KnownBits LHSKnown = KB->getKnownBits(LHS);
+    if (LHSKnown.isNegative())
+      Result = false;
+    if (LHSKnown.isNonNegative() && isKnownNonZero(LHS, MRI, KB))
+      Result = true;
+    break;
+  }
+  }
+
+  if (!Result)
+    return false;
+
+  MatchInfo = [=](MachineIRBuilder &B) {
+    if (*Result)
+      B.buildConstant(Dst, getICmpTrueVal(getTargetLowering(),
+                                          /*IsVector=*/DstTy.isVector(),
+                                          /*IsFP=*/false));
+    else
+      B.buildConstant(Dst, 0);
+  };
+
+  return true;
+}
diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
index cfdd9905c16fa6..e8b9d995a22768 100644
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -1984,3 +1984,326 @@ Type *llvm::getTypeForLLT(LLT Ty, LLVMContext &C) {
                            Ty.getElementCount());
   return IntegerType::get(C, Ty.getSizeInBits());
 }
+
+APInt llvm::GIConstant::getScalarValue() const {
+  assert(Kind == GIConstantKind::Scalar && "Expected scalar constant");
+  // Only Scalar constants may be queried here; callers check getKind() first.
+  return Value;
+}
+
+std::optional<GIConstant>
+llvm::GIConstant::getConstant(Register Const, const MachineRegisterInfo &MRI) {
+  // getDefIgnoringCopies may return null, and dyn_cast does not accept null.
+  MachineInstr *Constant = getDefIgnoringCopies(Const, MRI);
+  if (!Constant)
+    return std::nullopt;
+
+  if (GSplatVector *Splat = dyn_cast<GSplatVector>(Constant)) {
+    std::optional<ValueAndVReg> MayBeConstant =
+        getIConstantVRegValWithLookThrough(Splat->getValueReg(), MRI);
+    if (!MayBeConstant)
+      return std::nullopt;
+    return GIConstant(MayBeConstant->Value, GIConstantKind::ScalableVector);
+  }
+
+  if (GBuildVector *Build = dyn_cast<GBuildVector>(Constant)) {
+    SmallVector<APInt> Values;
+    for (unsigned I = 0, E = Build->getNumSources(); I != E; ++I) {
+      std::optional<ValueAndVReg> MayBeConstant =
+          getIConstantVRegValWithLookThrough(Build->getSourceReg(I), MRI);
+      if (!MayBeConstant)
+        return std::nullopt;
+      Values.push_back(MayBeConstant->Value);
+    }
+    return GIConstant(Values);
+  }
+
+  std::optional<ValueAndVReg> MayBeConstant =
+      getIConstantVRegValWithLookThrough(Const, MRI);
+  if (!MayBeConstant)
+    return std::nullopt;
+  return GIConstant(MayBeConstant->Value, GIConstantKind::Scalar);
+}
+
+static bool isKnownNonZero(Register Reg, const MachineRegisterInfo &MRI,
+                           GISelKnownBits *KB, unsigned Depth);
+
+bool llvm::isKnownNonZero(Register Reg, const MachineRegisterInfo &MRI,
+                          GISelKnownBits *KB, unsigned Depth) {
+  if (!Reg.isVirtual())
+    return false;
+
+  LLT Ty = MRI.getType(Reg);
+  if (!Ty.isValid())
+    return false;
+
+  if (Ty.isPointer())
+    return false;
+
+  if (!Ty.isScalar())
+    errs() << "type: " << Ty << '\n';
+
+  assert(...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/105991


More information about the llvm-commits mailing list