[llvm] [GISel] Combine out-of-range shifts with value to 0 or -1 (PR #123510)

via llvm-commits llvm-commits at lists.llvm.org
Sun Jan 19 04:36:08 PST 2025


https://github.com/lialan created https://github.com/llvm/llvm-project/pull/123510

This resolves #123212 

>From c2d4c0d031165c2c7729fb860a4205e8c46dba30 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Sun, 19 Jan 2025 12:42:52 +0800
Subject: [PATCH 1/2] First commit

---
 .../llvm/CodeGen/GlobalISel/CombinerHelper.h  | 28 +++++------
 .../include/llvm/Target/GlobalISel/Combine.td | 13 +++--
 .../lib/CodeGen/GlobalISel/CombinerHelper.cpp | 49 ++++++++++++++++++-
 3 files changed, 70 insertions(+), 20 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index 94e36e412b0cf7..da558de5946559 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -115,18 +115,13 @@ class CombinerHelper {
 
 public:
   CombinerHelper(GISelChangeObserver &Observer, MachineIRBuilder &B,
-                 bool IsPreLegalize,
-                 GISelKnownBits *KB = nullptr,
+                 bool IsPreLegalize, GISelKnownBits *KB = nullptr,
                  MachineDominatorTree *MDT = nullptr,
                  const LegalizerInfo *LI = nullptr);
 
-  GISelKnownBits *getKnownBits() const {
-    return KB;
-  }
+  GISelKnownBits *getKnownBits() const { return KB; }
 
-  MachineIRBuilder &getBuilder() const {
-    return Builder;
-  }
+  MachineIRBuilder &getBuilder() const { return Builder; }
 
   const TargetLowering &getTargetLowering() const;
 
@@ -150,8 +145,10 @@ class CombinerHelper {
   /// is a legal integer constant type on the target.
   bool isConstantLegalOrBeforeLegalizer(const LLT Ty) const;
 
-  /// MachineRegisterInfo::replaceRegWith() and inform the observer of the changes
-  void replaceRegWith(MachineRegisterInfo &MRI, Register FromReg, Register ToReg) const;
+  /// MachineRegisterInfo::replaceRegWith() and inform the observer of the
+  /// changes
+  void replaceRegWith(MachineRegisterInfo &MRI, Register FromReg,
+                      Register ToReg) const;
 
   /// Replace a single register operand with a new register and inform the
   /// observer of the changes.
@@ -482,12 +479,12 @@ class CombinerHelper {
   bool matchEqualDefs(const MachineOperand &MOP1,
                       const MachineOperand &MOP2) const;
 
-  /// Return true if \p MOP is defined by a G_CONSTANT or splat with a value equal to
-  /// \p C.
+  /// Return true if \p MOP is defined by a G_CONSTANT or splat with a value
+  /// equal to \p C.
   bool matchConstantOp(const MachineOperand &MOP, int64_t C) const;
 
-  /// Return true if \p MOP is defined by a G_FCONSTANT or splat with a value exactly
-  /// equal to \p C.
+  /// Return true if \p MOP is defined by a G_FCONSTANT or splat with a value
+  /// exactly equal to \p C.
   bool matchConstantFPOp(const MachineOperand &MOP, double C) const;
 
   /// @brief Checks if constant at \p ConstIdx is larger than \p MI 's bitwidth
@@ -841,7 +838,8 @@ class CombinerHelper {
                                      BuildFnTy &MatchInfo) const;
 
   /// Match shifts greater or equal to the bitwidth of the operation.
-  bool matchShiftsTooBig(MachineInstr &MI) const;
+  bool matchShiftsTooBig(MachineInstr &MI,
+                         std::optional<int64_t> &MatchInfo) const;
 
   /// Match constant LHS ops that should be commuted.
   bool matchCommuteConstantToRHS(MachineInstr &MI) const;
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 8641eabbdd84c6..ae0856550d9356 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -306,11 +306,18 @@ def ptr_add_immed_chain : GICombineRule<
          [{ return Helper.matchPtrAddImmedChain(*${d}, ${matchinfo}); }]),
   (apply [{ Helper.applyPtrAddImmedChain(*${d}, ${matchinfo}); }])>;
 
+def shift_result_matchdata : GIDefMatchData<"std::optional<int64_t>">;
 def shifts_too_big : GICombineRule<
-  (defs root:$root),
+  (defs root:$root, shift_result_matchdata:$matchinfo),
   (match (wip_match_opcode G_SHL, G_ASHR, G_LSHR):$root,
-         [{ return Helper.matchShiftsTooBig(*${root}); }]),
-  (apply [{ Helper.replaceInstWithUndef(*${root}); }])>;
+         [{ return Helper.matchShiftsTooBig(*${root}, ${matchinfo}); }]),
+  (apply [{
+    if (${matchinfo}) {
+      Helper.replaceInstWithConstant(*${root}, *${matchinfo});
+    } else {
+      Helper.replaceInstWithUndef(*${root});
+    }
+  }])>;
 
 // Fold shift (shift base x), y -> shift base, (x+y), if shifts are same
 def shift_immed_matchdata : GIDefMatchData<"RegisterImmPair">;
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 4e3aaf5da7198c..6c04337ad73e0b 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -35,6 +35,7 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/DivisionByConstantInfo.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/KnownBits.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Target/TargetMachine.h"
 #include <cmath>
@@ -6590,12 +6591,56 @@ bool CombinerHelper::matchRedundantBinOpInEquality(MachineInstr &MI,
   return CmpInst::isEquality(Pred) && Y.isValid();
 }
 
-bool CombinerHelper::matchShiftsTooBig(MachineInstr &MI) const {
+/// Smallest shift amount at which \p Opcode on a value with known bits
+/// \p ValueKB yields the constant stored in \p Result; nullopt if unprovable.
+static std::optional<unsigned>
+getMaxUsefulShift(KnownBits ValueKB, unsigned Opcode,
+                  std::optional<int64_t> &Result) {
+  assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_LSHR ||
+          Opcode == TargetOpcode::G_ASHR) && "Expect G_SHL, G_LSHR or G_ASHR.");
+  unsigned RedundantBits = 0;
+  switch (Opcode) {
+  case TargetOpcode::G_SHL:
+    RedundantBits = ValueKB.countMinTrailingZeros();
+    Result = 0;
+    break;
+  case TargetOpcode::G_LSHR:
+    RedundantBits = ValueKB.countMinLeadingZeros();
+    Result = 0;
+    break;
+  case TargetOpcode::G_ASHR:
+    if (ValueKB.isNonNegative()) {
+      RedundantBits = ValueKB.countMinLeadingZeros();
+      Result = 0;
+    } else if (ValueKB.isNegative()) {
+      RedundantBits = ValueKB.countMinLeadingOnes();
+      Result = -1;
+    } else {
+      // Unknown sign: no fold. NB: `return false` here would mean optional(0).
+      Result = std::nullopt;
+      return std::nullopt;
+    }
+    break;
+  }
+  return ValueKB.getBitWidth() - RedundantBits;
+}
+
+bool CombinerHelper::matchShiftsTooBig(
+    MachineInstr &MI, std::optional<int64_t> &MatchInfo) const {
+  Register ShiftVal = MI.getOperand(1).getReg();
   Register ShiftReg = MI.getOperand(2).getReg();
   LLT ResTy = MRI.getType(MI.getOperand(0).getReg());
+  MatchInfo = std::nullopt;
   auto IsShiftTooBig = [&](const Constant *C) {
     auto *CI = dyn_cast<ConstantInt>(C);
-    return CI && CI->uge(ResTy.getScalarSizeInBits());
+    if (!CI)
+      return false;
+    // An out-of-range lane accepts any value; keep what in-range lanes need.
+    if (CI->uge(ResTy.getScalarSizeInBits()))
+      return true;
+    auto OptMaxUsefulShift = getMaxUsefulShift(KB->getKnownBits(ShiftVal),
+                                               MI.getOpcode(), MatchInfo);
+    return OptMaxUsefulShift && CI->uge(*OptMaxUsefulShift);
   };
   return matchUnaryPredicate(MRI, ShiftReg, IsShiftTooBig);
 }

>From 37b8ac1177a5f9f2122f06cb6911b04f2b35af8e Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Sun, 19 Jan 2025 20:29:57 +0800
Subject: [PATCH 2/2] Fix up

---
 .../llvm/CodeGen/GlobalISel/CombinerHelper.h  | 28 +++++++++++--------
 .../lib/CodeGen/GlobalISel/CombinerHelper.cpp |  1 -
 2 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index da558de5946559..a51aa876e1deb0 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -115,13 +115,18 @@ class CombinerHelper {
 
 public:
   CombinerHelper(GISelChangeObserver &Observer, MachineIRBuilder &B,
-                 bool IsPreLegalize, GISelKnownBits *KB = nullptr,
+                 bool IsPreLegalize,
+                 GISelKnownBits *KB = nullptr,
                  MachineDominatorTree *MDT = nullptr,
                  const LegalizerInfo *LI = nullptr);
 
-  GISelKnownBits *getKnownBits() const { return KB; }
+  GISelKnownBits *getKnownBits() const {
+    return KB;
+  }
 
-  MachineIRBuilder &getBuilder() const { return Builder; }
+  MachineIRBuilder &getBuilder() const {
+    return Builder;
+  }
 
   const TargetLowering &getTargetLowering() const;
 
@@ -145,10 +150,8 @@ class CombinerHelper {
   /// is a legal integer constant type on the target.
   bool isConstantLegalOrBeforeLegalizer(const LLT Ty) const;
 
-  /// MachineRegisterInfo::replaceRegWith() and inform the observer of the
-  /// changes
-  void replaceRegWith(MachineRegisterInfo &MRI, Register FromReg,
-                      Register ToReg) const;
+  /// MachineRegisterInfo::replaceRegWith() and inform the observer of the changes
+  void replaceRegWith(MachineRegisterInfo &MRI, Register FromReg, Register ToReg) const;
 
   /// Replace a single register operand with a new register and inform the
   /// observer of the changes.
@@ -479,12 +482,12 @@ class CombinerHelper {
   bool matchEqualDefs(const MachineOperand &MOP1,
                       const MachineOperand &MOP2) const;
 
-  /// Return true if \p MOP is defined by a G_CONSTANT or splat with a value
-  /// equal to \p C.
+  /// Return true if \p MOP is defined by a G_CONSTANT or splat with a value equal to
+  /// \p C.
   bool matchConstantOp(const MachineOperand &MOP, int64_t C) const;
 
-  /// Return true if \p MOP is defined by a G_FCONSTANT or splat with a value
-  /// exactly equal to \p C.
+  /// Return true if \p MOP is defined by a G_FCONSTANT or splat with a value exactly
+  /// equal to \p C.
   bool matchConstantFPOp(const MachineOperand &MOP, double C) const;
 
   /// @brief Checks if constant at \p ConstIdx is larger than \p MI 's bitwidth
@@ -837,7 +840,8 @@ class CombinerHelper {
   bool matchRedundantBinOpInEquality(MachineInstr &MI,
                                      BuildFnTy &MatchInfo) const;
 
-  /// Match shifts greater or equal to the bitwidth of the operation.
+  /// Match shifts by at least the bitwidth of the operation, or by enough that
+  /// the known bits of the shifted value force the result to 0 or -1.
   bool matchShiftsTooBig(MachineInstr &MI,
                          std::optional<int64_t> &MatchInfo) const;
 
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 6c04337ad73e0b..b4cf1655b04728 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -35,7 +35,6 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/DivisionByConstantInfo.h"
 #include "llvm/Support/ErrorHandling.h"
-#include "llvm/Support/KnownBits.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Target/TargetMachine.h"
 #include <cmath>



More information about the llvm-commits mailing list