[llvm] 485dd0b - [GlobalISel] Handle constant splat in funnel shift combine

Abinav Puthan Purayil via llvm-commits llvm-commits at lists.llvm.org
Mon May 16 03:38:25 PDT 2022


Author: Abinav Puthan Purayil
Date: 2022-05-16T16:03:30+05:30
New Revision: 485dd0b752cd78c93b1c41e922a73c07b565a9f0

URL: https://github.com/llvm/llvm-project/commit/485dd0b752cd78c93b1c41e922a73c07b565a9f0
DIFF: https://github.com/llvm/llvm-project/commit/485dd0b752cd78c93b1c41e922a73c07b565a9f0.diff

LOG: [GlobalISel] Handle constant splat in funnel shift combine

This change adds the constant splat versions of m_ICst() (by using
getBuildVectorConstantSplat()) and uses it in
matchOrShiftToFunnelShift(). The getBuildVectorConstantSplat() name is
shortened to getIConstantSplatVal() so that the *SExtVal() version would
have a more compact name.

Differential Revision: https://reviews.llvm.org/D125516

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
    llvm/include/llvm/CodeGen/GlobalISel/Utils.h
    llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
    llvm/lib/CodeGen/GlobalISel/Utils.cpp
    llvm/test/CodeGen/AMDGPU/GlobalISel/combine-fsh.mir
    llvm/test/CodeGen/AMDGPU/GlobalISel/combine-rot.mir
    llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
index daf1ff052983f..1cacf96620f02 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
@@ -94,6 +94,48 @@ inline ConstantMatch<int64_t> m_ICst(int64_t &Cst) {
   return ConstantMatch<int64_t>(Cst);
 }
 
+template <typename ConstT>
+inline Optional<ConstT> matchConstantSplat(Register,
+                                           const MachineRegisterInfo &);
+
+template <>
+inline Optional<APInt> matchConstantSplat(Register Reg,
+                                          const MachineRegisterInfo &MRI) {
+  return getIConstantSplatVal(Reg, MRI);
+}
+
+template <>
+inline Optional<int64_t> matchConstantSplat(Register Reg,
+                                            const MachineRegisterInfo &MRI) {
+  return getIConstantSplatSExtVal(Reg, MRI);
+}
+
+template <typename ConstT> struct ICstOrSplatMatch {
+  ConstT &CR;
+  ICstOrSplatMatch(ConstT &C) : CR(C) {}
+  bool match(const MachineRegisterInfo &MRI, Register Reg) {
+    if (auto MaybeCst = matchConstant<ConstT>(Reg, MRI)) {
+      CR = *MaybeCst;
+      return true;
+    }
+
+    if (auto MaybeCstSplat = matchConstantSplat<ConstT>(Reg, MRI)) {
+      CR = *MaybeCstSplat;
+      return true;
+    }
+
+    return false;
+  };
+};
+
+inline ICstOrSplatMatch<APInt> m_ICstOrSplat(APInt &Cst) {
+  return ICstOrSplatMatch<APInt>(Cst);
+}
+
+inline ICstOrSplatMatch<int64_t> m_ICstOrSplat(int64_t &Cst) {
+  return ICstOrSplatMatch<int64_t>(Cst);
+}
+
 struct GCstAndRegMatch {
   Optional<ValueAndVReg> &ValReg;
   GCstAndRegMatch(Optional<ValueAndVReg> &ValReg) : ValReg(ValReg) {}

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
index 7c1c89d8d6a9a..78f1b49da822a 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
@@ -373,9 +373,23 @@ class RegOrConstant {
 /// If \p MI is not a splat, returns None.
 Optional<int> getSplatIndex(MachineInstr &MI);
 
-/// Returns a scalar constant of a G_BUILD_VECTOR splat if it exists.
-Optional<int64_t> getBuildVectorConstantSplat(const MachineInstr &MI,
-                                              const MachineRegisterInfo &MRI);
+/// \returns the scalar integral splat value of \p Reg if possible.
+Optional<APInt> getIConstantSplatVal(const Register Reg,
+                                     const MachineRegisterInfo &MRI);
+
+/// \returns the scalar integral splat value defined by \p MI if possible.
+Optional<APInt> getIConstantSplatVal(const MachineInstr &MI,
+                                     const MachineRegisterInfo &MRI);
+
+/// \returns the scalar sign extended integral splat value of \p Reg if
+/// possible.
+Optional<int64_t> getIConstantSplatSExtVal(const Register Reg,
+                                           const MachineRegisterInfo &MRI);
+
+/// \returns the scalar sign extended integral splat value defined by \p MI if
+/// possible.
+Optional<int64_t> getIConstantSplatSExtVal(const MachineInstr &MI,
+                                           const MachineRegisterInfo &MRI);
 
 /// Returns a floating point scalar constant of a build vector splat if it
 /// exists. When \p AllowUndef == true some elements can be undef but not all.

diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index cdd85d4ba84ea..80f556cbaad79 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -2945,7 +2945,7 @@ bool CombinerHelper::matchNotCmp(MachineInstr &MI,
   int64_t Cst;
   if (Ty.isVector()) {
     MachineInstr *CstDef = MRI.getVRegDef(CstReg);
-    auto MaybeCst = getBuildVectorConstantSplat(*CstDef, MRI);
+    auto MaybeCst = getIConstantSplatSExtVal(*CstDef, MRI);
     if (!MaybeCst)
       return false;
     if (!isConstValidTrue(TLI, Ty.getScalarSizeInBits(), *MaybeCst, true, IsFP))
@@ -4029,10 +4029,9 @@ bool CombinerHelper::matchOrShiftToFunnelShift(MachineInstr &MI,
 
   // Given constants C0 and C1 such that C0 + C1 is bit-width:
   // (or (shl x, C0), (lshr y, C1)) -> (fshl x, y, C0) or (fshr x, y, C1)
-  // TODO: Match constant splat.
   int64_t CstShlAmt, CstLShrAmt;
-  if (mi_match(ShlAmt, MRI, m_ICst(CstShlAmt)) &&
-      mi_match(LShrAmt, MRI, m_ICst(CstLShrAmt)) &&
+  if (mi_match(ShlAmt, MRI, m_ICstOrSplat(CstShlAmt)) &&
+      mi_match(LShrAmt, MRI, m_ICstOrSplat(CstLShrAmt)) &&
       CstShlAmt + CstLShrAmt == BitWidth) {
     FshOpc = TargetOpcode::G_FSHR;
     Amt = LShrAmt;

diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
index cf80350f17dea..425a155d262d2 100644
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -1071,15 +1071,38 @@ bool llvm::isBuildVectorConstantSplat(const MachineInstr &MI,
                                     AllowUndef);
 }
 
+Optional<APInt> llvm::getIConstantSplatVal(const Register Reg,
+                                           const MachineRegisterInfo &MRI) {
+  if (auto SplatValAndReg =
+          getAnyConstantSplat(Reg, MRI, /* AllowUndef */ false)) {
+    Optional<ValueAndVReg> ValAndVReg =
+        getIConstantVRegValWithLookThrough(SplatValAndReg->VReg, MRI);
+    return ValAndVReg->Value;
+  }
+
+  return None;
+}
+
+Optional<APInt> llvm::getIConstantSplatVal(const MachineInstr &MI,
+                                           const MachineRegisterInfo &MRI) {
+  return getIConstantSplatVal(MI.getOperand(0).getReg(), MRI);
+}
+
 Optional<int64_t>
-llvm::getBuildVectorConstantSplat(const MachineInstr &MI,
-                                  const MachineRegisterInfo &MRI) {
+llvm::getIConstantSplatSExtVal(const Register Reg,
+                               const MachineRegisterInfo &MRI) {
   if (auto SplatValAndReg =
-          getAnyConstantSplat(MI.getOperand(0).getReg(), MRI, false))
+          getAnyConstantSplat(Reg, MRI, /* AllowUndef */ false))
     return getIConstantVRegSExtVal(SplatValAndReg->VReg, MRI);
   return None;
 }
 
+Optional<int64_t>
+llvm::getIConstantSplatSExtVal(const MachineInstr &MI,
+                               const MachineRegisterInfo &MRI) {
+  return getIConstantSplatSExtVal(MI.getOperand(0).getReg(), MRI);
+}
+
 Optional<FPValueAndVReg> llvm::getFConstantSplat(Register VReg,
                                                  const MachineRegisterInfo &MRI,
                                                  bool AllowUndef) {
@@ -1105,7 +1128,7 @@ Optional<RegOrConstant> llvm::getVectorSplat(const MachineInstr &MI,
   unsigned Opc = MI.getOpcode();
   if (!isBuildVectorOp(Opc))
     return None;
-  if (auto Splat = getBuildVectorConstantSplat(MI, MRI))
+  if (auto Splat = getIConstantSplatSExtVal(MI, MRI))
     return RegOrConstant(*Splat);
   auto Reg = MI.getOperand(1).getReg();
   if (any_of(make_range(MI.operands_begin() + 2, MI.operands_end()),
@@ -1176,7 +1199,7 @@ llvm::isConstantOrConstantSplatVector(MachineInstr &MI,
   Register Def = MI.getOperand(0).getReg();
   if (auto C = getIConstantVRegValWithLookThrough(Def, MRI))
     return C->Value;
-  auto MaybeCst = getBuildVectorConstantSplat(MI, MRI);
+  auto MaybeCst = getIConstantSplatSExtVal(MI, MRI);
   if (!MaybeCst)
     return None;
   const unsigned ScalarSize = MRI.getType(Def).getScalarSizeInBits();

diff --git a/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-fsh.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-fsh.mir
index 6fdf284d09471..c724e088b2fdb 100644
--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-fsh.mir
+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-fsh.mir
@@ -143,13 +143,9 @@ body: |
     ; CHECK-NEXT: {{  $}}
     ; CHECK-NEXT: %a:_(<2 x s32>) = COPY $vgpr0_vgpr1
     ; CHECK-NEXT: %b:_(<2 x s32>) = COPY $vgpr2_vgpr3
-    ; CHECK-NEXT: %scalar_amt0:_(s32) = G_CONSTANT i32 20
-    ; CHECK-NEXT: %amt0:_(<2 x s32>) = G_BUILD_VECTOR %scalar_amt0(s32), %scalar_amt0(s32)
     ; CHECK-NEXT: %scalar_amt1:_(s32) = G_CONSTANT i32 12
     ; CHECK-NEXT: %amt1:_(<2 x s32>) = G_BUILD_VECTOR %scalar_amt1(s32), %scalar_amt1(s32)
-    ; CHECK-NEXT: %shl:_(<2 x s32>) = G_SHL %a, %amt0(<2 x s32>)
-    ; CHECK-NEXT: %lshr:_(<2 x s32>) = G_LSHR %b, %amt1(<2 x s32>)
-    ; CHECK-NEXT: %or:_(<2 x s32>) = G_OR %shl, %lshr
+    ; CHECK-NEXT: %or:_(<2 x s32>) = G_FSHR %a, %b, %amt1(<2 x s32>)
     ; CHECK-NEXT: $vgpr4_vgpr5 = COPY %or(<2 x s32>)
     %a:_(<2 x s32>) = COPY $vgpr0_vgpr1
     %b:_(<2 x s32>) = COPY $vgpr2_vgpr3

diff --git a/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-rot.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-rot.mir
index bdd860bb9d63b..864f937c41c48 100644
--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-rot.mir
+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-rot.mir
@@ -132,13 +132,9 @@ body: |
     ; CHECK: liveins: $vgpr0_vgpr1, $vgpr2_vgpr3
     ; CHECK-NEXT: {{  $}}
     ; CHECK-NEXT: %a:_(<2 x s32>) = COPY $vgpr0_vgpr1
-    ; CHECK-NEXT: %scalar_amt0:_(s32) = G_CONSTANT i32 20
-    ; CHECK-NEXT: %amt0:_(<2 x s32>) = G_BUILD_VECTOR %scalar_amt0(s32), %scalar_amt0(s32)
     ; CHECK-NEXT: %scalar_amt1:_(s32) = G_CONSTANT i32 12
     ; CHECK-NEXT: %amt1:_(<2 x s32>) = G_BUILD_VECTOR %scalar_amt1(s32), %scalar_amt1(s32)
-    ; CHECK-NEXT: %shl:_(<2 x s32>) = G_SHL %a, %amt0(<2 x s32>)
-    ; CHECK-NEXT: %lshr:_(<2 x s32>) = G_LSHR %a, %amt1(<2 x s32>)
-    ; CHECK-NEXT: %or:_(<2 x s32>) = G_OR %shl, %lshr
+    ; CHECK-NEXT: %or:_(<2 x s32>) = G_ROTR %a, %amt1(<2 x s32>)
     ; CHECK-NEXT: $vgpr2_vgpr3 = COPY %or(<2 x s32>)
     %a:_(<2 x s32>) = COPY $vgpr0_vgpr1
     %scalar_amt0:_(s32) = G_CONSTANT i32 20

diff --git a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
index da0aee6b46c29..cb8c7b46b6aa2 100644
--- a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
@@ -51,6 +51,25 @@ TEST_F(AArch64GISelMITest, MatchIntConstantRegister) {
   EXPECT_EQ(Src0->VReg, MIBCst.getReg(0));
 }
 
+TEST_F(AArch64GISelMITest, MatchIntConstantSplat) {
+  setUp();
+  if (!TM)
+    return;
+
+  LLT s64 = LLT::scalar(64);
+  LLT v4s64 = LLT::fixed_vector(4, s64);
+
+  MachineInstrBuilder FortyTwoSplat =
+      B.buildSplatVector(v4s64, B.buildConstant(s64, 42));
+  int64_t Cst;
+  EXPECT_TRUE(mi_match(FortyTwoSplat.getReg(0), *MRI, m_ICstOrSplat(Cst)));
+  EXPECT_EQ(Cst, 42);
+
+  MachineInstrBuilder NonConstantSplat =
+      B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]});
+  EXPECT_FALSE(mi_match(NonConstantSplat.getReg(0), *MRI, m_ICstOrSplat(Cst)));
+}
+
 TEST_F(AArch64GISelMITest, MachineInstrPtrBind) {
   setUp();
   if (!TM)


        


More information about the llvm-commits mailing list