[llvm] [GlobalISel] Add constant matcher for APInt (PR #151357)

via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 1 07:05:34 PDT 2025


https://github.com/jyli0116 updated https://github.com/llvm/llvm-project/pull/151357

>From 66454feb05d9f498f1f8fb49c35a99a42808cfb7 Mon Sep 17 00:00:00 2001
From: Yu Li <yu.li at arm.com>
Date: Wed, 30 Jul 2025 15:36:42 +0000
Subject: [PATCH 1/2] [GlobalISel] Add constant matcher for APInt

---
 .../llvm/CodeGen/GlobalISel/MIPatternMatch.h  | 67 +++++++++++++------
 llvm/include/llvm/CodeGen/GlobalISel/Utils.h  | 12 ++++
 llvm/lib/CodeGen/GlobalISel/Utils.cpp         | 22 ++++++
 .../CodeGen/GlobalISel/PatternMatchTest.cpp   | 40 +++++++++++
 4 files changed, 122 insertions(+), 19 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
index c0d3a12cbcb41..e8d9bc03f6428 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
@@ -192,24 +192,35 @@ m_GFCstOrSplat(std::optional<FPValueAndVReg> &FPValReg) {
 
 /// Matcher for a specific constant value.
 struct SpecificConstantMatch {
-  int64_t RequestedVal;
-  SpecificConstantMatch(int64_t RequestedVal) : RequestedVal(RequestedVal) {}
+  APInt RequestedVal;
+  SpecificConstantMatch(APInt RequestedVal) : RequestedVal(RequestedVal) {}
   bool match(const MachineRegisterInfo &MRI, Register Reg) {
-    int64_t MatchedVal;
-    return mi_match(Reg, MRI, m_ICst(MatchedVal)) && MatchedVal == RequestedVal;
+    APInt MatchedVal;
+    if (mi_match(Reg, MRI, m_ICst(MatchedVal))) {
+      if (MatchedVal.getBitWidth() > RequestedVal.getBitWidth())
+        RequestedVal = RequestedVal.sext(MatchedVal.getBitWidth());
+      else
+        MatchedVal = MatchedVal.sext(RequestedVal.getBitWidth());
+
+      return APInt::isSameValue(MatchedVal, RequestedVal);
+    }
+    return false;
   }
 };
 
 /// Matches a constant equal to \p RequestedValue.
+inline SpecificConstantMatch m_SpecificICst(APInt RequestedValue) {
+  return SpecificConstantMatch(std::move(RequestedValue));
+}
+
 inline SpecificConstantMatch m_SpecificICst(int64_t RequestedValue) {
-  return SpecificConstantMatch(RequestedValue);
+  return SpecificConstantMatch(APInt(64, RequestedValue, /* isSigned */ true));
 }
 
 /// Matcher for a specific constant splat.
 struct SpecificConstantSplatMatch {
-  int64_t RequestedVal;
-  SpecificConstantSplatMatch(int64_t RequestedVal)
-      : RequestedVal(RequestedVal) {}
+  APInt RequestedVal;
+  SpecificConstantSplatMatch(APInt RequestedVal) : RequestedVal(RequestedVal) {}
   bool match(const MachineRegisterInfo &MRI, Register Reg) {
     return isBuildVectorConstantSplat(Reg, MRI, RequestedVal,
                                       /* AllowUndef */ false);
@@ -217,19 +228,31 @@ struct SpecificConstantSplatMatch {
 };
 
 /// Matches a constant splat of \p RequestedValue.
+inline SpecificConstantSplatMatch m_SpecificICstSplat(APInt RequestedValue) {
+  return SpecificConstantSplatMatch(std::move(RequestedValue));
+}
+
 inline SpecificConstantSplatMatch m_SpecificICstSplat(int64_t RequestedValue) {
-  return SpecificConstantSplatMatch(RequestedValue);
+  return SpecificConstantSplatMatch(
+      APInt(64, RequestedValue, /* isSigned */ true));
 }
 
 /// Matcher for a specific constant or constant splat.
 struct SpecificConstantOrSplatMatch {
-  int64_t RequestedVal;
-  SpecificConstantOrSplatMatch(int64_t RequestedVal)
+  APInt RequestedVal;
+  SpecificConstantOrSplatMatch(APInt RequestedVal)
       : RequestedVal(RequestedVal) {}
   bool match(const MachineRegisterInfo &MRI, Register Reg) {
-    int64_t MatchedVal;
-    if (mi_match(Reg, MRI, m_ICst(MatchedVal)) && MatchedVal == RequestedVal)
-      return true;
+    APInt MatchedVal;
+    if (mi_match(Reg, MRI, m_ICst(MatchedVal))) {
+      if (MatchedVal.getBitWidth() > RequestedVal.getBitWidth())
+        RequestedVal = RequestedVal.sext(MatchedVal.getBitWidth());
+      else
+        MatchedVal = MatchedVal.sext(RequestedVal.getBitWidth());
+
+      if (APInt::isSameValue(MatchedVal, RequestedVal))
+        return true;
+    }
     return isBuildVectorConstantSplat(Reg, MRI, RequestedVal,
                                       /* AllowUndef */ false);
   }
@@ -237,18 +260,24 @@ struct SpecificConstantOrSplatMatch {
 
 /// Matches a \p RequestedValue constant or a constant splat of \p
 /// RequestedValue.
+inline SpecificConstantOrSplatMatch
+m_SpecificICstOrSplat(APInt RequestedValue) {
+  return SpecificConstantOrSplatMatch(std::move(RequestedValue));
+}
+
 inline SpecificConstantOrSplatMatch
 m_SpecificICstOrSplat(int64_t RequestedValue) {
-  return SpecificConstantOrSplatMatch(RequestedValue);
+  return SpecificConstantOrSplatMatch(
+      APInt(64, RequestedValue, /* isSigned */ true));
 }
 
-///{
 /// Convenience matchers for specific integer values.
-inline SpecificConstantMatch m_ZeroInt() { return SpecificConstantMatch(0); }
+inline SpecificConstantMatch m_ZeroInt() {
+  return SpecificConstantMatch(APInt(64, 0));
+}
 inline SpecificConstantMatch m_AllOnesInt() {
-  return SpecificConstantMatch(-1);
+  return SpecificConstantMatch(APInt(64, -1, /* isSigned */ true));
 }
-///}
 
 /// Matcher for a specific register.
 struct SpecificRegisterMatch {
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
index 66c960fe12c68..5c27605c26883 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
@@ -459,12 +459,24 @@ LLVM_ABI bool isBuildVectorConstantSplat(const Register Reg,
                                          const MachineRegisterInfo &MRI,
                                          int64_t SplatValue, bool AllowUndef);
 
+/// Return true if the specified register is defined by G_BUILD_VECTOR or
+/// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef.
+LLVM_ABI bool isBuildVectorConstantSplat(const Register Reg,
+                                         const MachineRegisterInfo &MRI,
+                                         APInt SplatValue, bool AllowUndef);
+
 /// Return true if the specified instruction is a G_BUILD_VECTOR or
 /// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef.
 LLVM_ABI bool isBuildVectorConstantSplat(const MachineInstr &MI,
                                          const MachineRegisterInfo &MRI,
                                          int64_t SplatValue, bool AllowUndef);
 
+/// Return true if the specified instruction is a G_BUILD_VECTOR or
+/// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef.
+LLVM_ABI bool isBuildVectorConstantSplat(const MachineInstr &MI,
+                                         const MachineRegisterInfo &MRI,
+                                         APInt SplatValue, bool AllowUndef);
+
 /// Return true if the specified instruction is a G_BUILD_VECTOR or
 /// G_BUILD_VECTOR_TRUNC where all of the elements are 0 or undef.
 LLVM_ABI bool isBuildVectorAllZeros(const MachineInstr &MI,
diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
index f48bfc06c14be..8955dd0370539 100644
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -1401,6 +1401,21 @@ bool llvm::isBuildVectorConstantSplat(const Register Reg,
   return false;
 }
 
+bool llvm::isBuildVectorConstantSplat(const Register Reg,
+                                      const MachineRegisterInfo &MRI,
+                                      APInt SplatValue, bool AllowUndef) {
+  if (auto SplatValAndReg = getAnyConstantSplat(Reg, MRI, AllowUndef)) {
+    if (SplatValAndReg->Value.getBitWidth() < SplatValue.getBitWidth())
+      return APInt::isSameValue(
+          SplatValAndReg->Value.sext(SplatValue.getBitWidth()), SplatValue);
+    return APInt::isSameValue(
+        SplatValAndReg->Value,
+        SplatValue.sext(SplatValAndReg->Value.getBitWidth()));
+  }
+
+  return false;
+}
+
 bool llvm::isBuildVectorConstantSplat(const MachineInstr &MI,
                                       const MachineRegisterInfo &MRI,
                                       int64_t SplatValue, bool AllowUndef) {
@@ -1408,6 +1423,13 @@ bool llvm::isBuildVectorConstantSplat(const MachineInstr &MI,
                                     AllowUndef);
 }
 
+bool llvm::isBuildVectorConstantSplat(const MachineInstr &MI,
+                                      const MachineRegisterInfo &MRI,
+                                      APInt SplatValue, bool AllowUndef) {
+  return isBuildVectorConstantSplat(MI.getOperand(0).getReg(), MRI, SplatValue,
+                                    AllowUndef);
+}
+
 std::optional<APInt>
 llvm::getIConstantSplatVal(const Register Reg, const MachineRegisterInfo &MRI) {
   if (auto SplatValAndReg =
diff --git a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
index 25eb67e981588..1e0653b61e8f8 100644
--- a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
@@ -634,17 +634,25 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstant) {
   auto FortyTwo = B.buildConstant(LLT::scalar(64), 42);
   EXPECT_TRUE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(42)));
   EXPECT_FALSE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(123)));
+  EXPECT_TRUE(
+      mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(APInt(64, 42))));
+  EXPECT_FALSE(
+      mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(APInt(64, 123))));
 
   // Test that this works inside of a more complex pattern.
   LLT s64 = LLT::scalar(64);
   auto MIBAdd = B.buildAdd(s64, Copies[0], FortyTwo);
   EXPECT_TRUE(mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(42)));
+  EXPECT_TRUE(mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(APInt(64, 42))));
 
   // Wrong constant.
   EXPECT_FALSE(mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(123)));
+  EXPECT_FALSE(
+      mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(APInt(64, 123))));
 
   // No constant on the LHS.
   EXPECT_FALSE(mi_match(MIBAdd.getReg(1), *MRI, m_SpecificICst(42)));
+  EXPECT_FALSE(mi_match(MIBAdd.getReg(1), *MRI, m_SpecificICst(APInt(64, 42))));
 }
 
 TEST_F(AArch64GISelMITest, MatchSpecificConstantSplat) {
@@ -664,6 +672,13 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantSplat) {
       mi_match(FortyTwoSplat.getReg(0), *MRI, m_SpecificICstSplat(43)));
   EXPECT_FALSE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstSplat(42)));
 
+  EXPECT_TRUE(mi_match(FortyTwoSplat.getReg(0), *MRI,
+                       m_SpecificICstSplat(APInt(64, 42))));
+  EXPECT_FALSE(mi_match(FortyTwoSplat.getReg(0), *MRI,
+                        m_SpecificICstSplat(APInt(64, 43))));
+  EXPECT_FALSE(
+      mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstSplat(APInt(64, 42))));
+
   MachineInstrBuilder NonConstantSplat =
       B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]});
 
@@ -673,8 +688,17 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantSplat) {
   EXPECT_FALSE(mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(43)));
   EXPECT_FALSE(mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstSplat(42)));
 
+  EXPECT_TRUE(
+      mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(APInt(64, 42))));
+  EXPECT_FALSE(
+      mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(APInt(64, 43))));
+  EXPECT_FALSE(
+      mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstSplat(APInt(64, 42))));
+
   MachineInstrBuilder Add = B.buildAdd(s64, Copies[0], FortyTwo);
   EXPECT_FALSE(mi_match(Add.getReg(2), *MRI, m_SpecificICstSplat(42)));
+  EXPECT_FALSE(
+      mi_match(Add.getReg(2), *MRI, m_SpecificICstSplat(APInt(64, 42))));
 }
 
 TEST_F(AArch64GISelMITest, MatchSpecificConstantOrSplat) {
@@ -695,6 +719,13 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantOrSplat) {
       mi_match(FortyTwoSplat.getReg(0), *MRI, m_SpecificICstOrSplat(43)));
   EXPECT_TRUE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstOrSplat(42)));
 
+  EXPECT_TRUE(mi_match(FortyTwoSplat.getReg(0), *MRI,
+                       m_SpecificICstOrSplat(APInt(64, 42))));
+  EXPECT_FALSE(mi_match(FortyTwoSplat.getReg(0), *MRI,
+                        m_SpecificICstOrSplat(APInt(64, 43))));
+  EXPECT_TRUE(
+      mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstOrSplat(APInt(64, 42))));
+
   MachineInstrBuilder NonConstantSplat =
       B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]});
 
@@ -704,8 +735,17 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantOrSplat) {
   EXPECT_FALSE(mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(43)));
   EXPECT_FALSE(mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstOrSplat(42)));
 
+  EXPECT_TRUE(
+      mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(APInt(64, 42))));
+  EXPECT_FALSE(
+      mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(APInt(64, 43))));
+  EXPECT_FALSE(
+      mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstOrSplat(APInt(64, 42))));
+
   MachineInstrBuilder Add = B.buildAdd(s64, Copies[0], FortyTwo);
   EXPECT_TRUE(mi_match(Add.getReg(2), *MRI, m_SpecificICstOrSplat(42)));
+  EXPECT_TRUE(
+      mi_match(Add.getReg(2), *MRI, m_SpecificICstOrSplat(APInt(64, 42))));
 }
 
 TEST_F(AArch64GISelMITest, MatchZeroInt) {

>From 57005c01a39dce83232e264dc45de5099563da93 Mon Sep 17 00:00:00 2001
From: Yu Li <yu.li at arm.com>
Date: Fri, 1 Aug 2025 14:05:00 +0000
Subject: [PATCH 2/2] Address Comments

---
 llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
index e8d9bc03f6428..66f0a61ec316a 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
@@ -193,7 +193,7 @@ m_GFCstOrSplat(std::optional<FPValueAndVReg> &FPValReg) {
 /// Matcher for a specific constant value.
 struct SpecificConstantMatch {
   APInt RequestedVal;
-  SpecificConstantMatch(APInt RequestedVal) : RequestedVal(RequestedVal) {}
+  SpecificConstantMatch(const APInt RequestedVal) : RequestedVal(RequestedVal) {}
   bool match(const MachineRegisterInfo &MRI, Register Reg) {
     APInt MatchedVal;
     if (mi_match(Reg, MRI, m_ICst(MatchedVal))) {
@@ -220,7 +220,7 @@ inline SpecificConstantMatch m_SpecificICst(int64_t RequestedValue) {
 /// Matcher for a specific constant splat.
 struct SpecificConstantSplatMatch {
   APInt RequestedVal;
-  SpecificConstantSplatMatch(APInt RequestedVal) : RequestedVal(RequestedVal) {}
+  SpecificConstantSplatMatch(const APInt RequestedVal) : RequestedVal(RequestedVal) {}
   bool match(const MachineRegisterInfo &MRI, Register Reg) {
     return isBuildVectorConstantSplat(Reg, MRI, RequestedVal,
                                       /* AllowUndef */ false);
@@ -240,7 +240,7 @@ inline SpecificConstantSplatMatch m_SpecificICstSplat(int64_t RequestedValue) {
 /// Matcher for a specific constant or constant splat.
 struct SpecificConstantOrSplatMatch {
   APInt RequestedVal;
-  SpecificConstantOrSplatMatch(APInt RequestedVal)
+  SpecificConstantOrSplatMatch(const APInt RequestedVal)
       : RequestedVal(RequestedVal) {}
   bool match(const MachineRegisterInfo &MRI, Register Reg) {
     APInt MatchedVal;
@@ -273,10 +273,10 @@ m_SpecificICstOrSplat(int64_t RequestedValue) {
 
 /// Convenience matchers for specific integer values.
 inline SpecificConstantMatch m_ZeroInt() {
-  return SpecificConstantMatch(APInt(64, 0));
+  return SpecificConstantMatch(APInt::getZero(64));
 }
 inline SpecificConstantMatch m_AllOnesInt() {
-  return SpecificConstantMatch(APInt(64, -1, /* isSigned */ true));
+  return SpecificConstantMatch(APInt::getAllOnes(64));
 }
 
 /// Matcher for a specific register.



More information about the llvm-commits mailing list