[llvm] [GlobalISel] Fold G_ICMP if possible (PR #86357)

Shilei Tian via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 22 16:22:22 PDT 2024


https://github.com/shiltian created https://github.com/llvm/llvm-project/pull/86357

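This patch constant folds `G_ICMP` in `CSEMIRBuilder::buildInstr` when both operands are constants (scalars, or build vectors whose elements are all constants).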


>From 10c510724da8d8dcb5b9e7b74fed9367d927ba55 Mon Sep 17 00:00:00 2001
From: Shilei Tian <i at tianshilei.me>
Date: Fri, 22 Mar 2024 19:21:08 -0400
Subject: [PATCH] [GlobalISel] Fold G_ICMP if possible

This patch tries to fold `G_ICMP` if possible.
---
 llvm/include/llvm/CodeGen/GlobalISel/Utils.h  |  4 ++
 llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp | 14 ++++
 llvm/lib/CodeGen/GlobalISel/Utils.cpp         | 68 +++++++++++++++++++
 3 files changed, 86 insertions(+)

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
index f8900f3434ccaa..fc4ebf795a334e 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
@@ -313,6 +313,10 @@ std::optional<APFloat> ConstantFoldIntToFloat(unsigned Opcode, LLT DstTy,
 std::optional<SmallVector<unsigned>>
 ConstantFoldCTLZ(Register Src, const MachineRegisterInfo &MRI);
 
+std::optional<SmallVector<APInt>>
+ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
+                 const MachineRegisterInfo &MRI);
+
 /// Test if the given value is known to have exactly one bit set. This differs
 /// from computeKnownBits in that it doesn't necessarily determine which bit is
 /// set.
diff --git a/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp b/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
index 1869e0d41a51f6..44e1be2a5e080d 100644
--- a/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
@@ -174,6 +174,20 @@ MachineInstrBuilder CSEMIRBuilder::buildInstr(unsigned Opc,
   switch (Opc) {
   default:
     break;
+  case TargetOpcode::G_ICMP: {
+    assert(SrcOps.size() == 3 && "Invalid sources");
+    assert(DstOps.size() == 1 && "Invalid dsts");
+    LLT SrcTy = SrcOps[1].getLLTTy(*getMRI());
+
+    if (std::optional<SmallVector<APInt>> Cst =
+            ConstantFoldICmp(SrcOps[0].getPredicate(), SrcOps[1].getReg(),
+                             SrcOps[2].getReg(), *getMRI())) {
+      if (SrcTy.isVector())
+        return buildBuildVectorConstant(DstOps[0], *Cst);
+      return buildConstant(DstOps[0], Cst->front());
+    }
+    break;
+  }
   case TargetOpcode::G_ADD:
   case TargetOpcode::G_PTR_ADD:
   case TargetOpcode::G_AND:
diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
index a9fa73b60a097f..133580f16bdfa7 100644
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -996,6 +996,74 @@ llvm::ConstantFoldCTLZ(Register Src, const MachineRegisterInfo &MRI) {
   return std::nullopt;
 }
 
+/// Try to constant fold a G_ICMP of \p Op1 and \p Op2 with predicate \p Pred.
+/// Vectors fold lane by lane and need both operands defined by a GBuildVector.
+/// \returns one 1-bit APInt per lane (one entry for scalars), or std::nullopt
+/// if the types differ, a value has no known constant, or \p Pred is unhandled.
+std::optional<SmallVector<APInt>>
+llvm::ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
+                       const MachineRegisterInfo &MRI) {
+  LLT Ty = MRI.getType(Op1);
+  if (Ty != MRI.getType(Op2))
+    return std::nullopt;
+
+  auto TryFoldScalar = [&MRI, Pred](Register LHS,
+                                    Register RHS) -> std::optional<APInt> {
+    auto LHSCst = getIConstantVRegVal(LHS, MRI);
+    auto RHSCst = getIConstantVRegVal(RHS, MRI);
+    if (!LHSCst || !RHSCst)
+      return std::nullopt;
+    switch (Pred) {
+    case CmpInst::Predicate::ICMP_EQ:
+      return APInt(1, LHSCst->eq(*RHSCst));
+    case CmpInst::Predicate::ICMP_NE:
+      return APInt(1, LHSCst->ne(*RHSCst));
+    case CmpInst::Predicate::ICMP_UGT:
+      return APInt(1, LHSCst->ugt(*RHSCst));
+    case CmpInst::Predicate::ICMP_UGE:
+      return APInt(1, LHSCst->uge(*RHSCst));
+    case CmpInst::Predicate::ICMP_ULT:
+      return APInt(1, LHSCst->ult(*RHSCst));
+    case CmpInst::Predicate::ICMP_ULE:
+      return APInt(1, LHSCst->ule(*RHSCst));
+    case CmpInst::Predicate::ICMP_SGT:
+      return APInt(1, LHSCst->sgt(*RHSCst));
+    case CmpInst::Predicate::ICMP_SGE:
+      return APInt(1, LHSCst->sge(*RHSCst));
+    case CmpInst::Predicate::ICMP_SLT:
+      return APInt(1, LHSCst->slt(*RHSCst));
+    case CmpInst::Predicate::ICMP_SLE:
+      return APInt(1, LHSCst->sle(*RHSCst));
+    default:
+      return std::nullopt;
+    }
+  };
+
+  SmallVector<APInt> FoldedICmps;
+  if (Ty.isVector()) {
+    // Try to constant fold each element.
+    auto *BV1 = getOpcodeDef<GBuildVector>(Op1, MRI);
+    auto *BV2 = getOpcodeDef<GBuildVector>(Op2, MRI);
+    if (!BV1 || !BV2)
+      return std::nullopt;
+    assert(BV1->getNumSources() == BV2->getNumSources() && "Invalid vectors");
+    for (unsigned I = 0; I < BV1->getNumSources(); ++I) {
+      auto MaybeFold =
+          TryFoldScalar(BV1->getSourceReg(I), BV2->getSourceReg(I));
+      if (!MaybeFold)
+        return std::nullopt; // A single unfoldable lane blocks the fold.
+      FoldedICmps.emplace_back(*MaybeFold);
+    }
+    return FoldedICmps;
+  }
+  // Scalar operands: the result has a single entry.
+  auto MaybeCst = TryFoldScalar(Op1, Op2);
+  if (!MaybeCst)
+    return std::nullopt;
+  FoldedICmps.emplace_back(*MaybeCst);
+  return FoldedICmps;
+}
+
 bool llvm::isKnownToBeAPowerOfTwo(Register Reg, const MachineRegisterInfo &MRI,
                                   GISelKnownBits *KB) {
   std::optional<DefinitionAndSourceRegister> DefSrcReg =



More information about the llvm-commits mailing list