[llvm] 8bc7185 - GlobalISel/Utils: Refactor constant splat match functions

Petar Avramovic via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 21 03:10:21 PDT 2021


Author: Petar Avramovic
Date: 2021-09-21T12:09:35+02:00
New Revision: 8bc71856681c235a3192813947308a19577c9236

URL: https://github.com/llvm/llvm-project/commit/8bc71856681c235a3192813947308a19577c9236
DIFF: https://github.com/llvm/llvm-project/commit/8bc71856681c235a3192813947308a19577c9236.diff

LOG: GlobalISel/Utils: Refactor constant splat match functions

Add a generic helper function that matches a constant splat. It has an option to
match a constant splat containing undef elements (some elements may be undef, but not all).
Add a util function and a matcher for G_FCONSTANT splats.

Differential Revision: https://reviews.llvm.org/D104410

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
    llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
    llvm/include/llvm/CodeGen/GlobalISel/Utils.h
    llvm/lib/CodeGen/GlobalISel/Utils.cpp
    llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
index 2b0ef6c3af574..bb5f55789a0ec 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
@@ -206,6 +206,14 @@ class GPtrAdd : public GenericMachineInstr {
   }
 };
 
+/// Represents a G_IMPLICIT_DEF.
+class GImplicitDef : public GenericMachineInstr {
+public:
+  static bool classof(const MachineInstr *MI) {
+    return MI->getOpcode() == TargetOpcode::G_IMPLICIT_DEF;
+  }
+};
+
 } // namespace llvm
 
 #endif // LLVM_CODEGEN_GLOBALISEL_GENERICMACHINEINSTRS_H
\ No newline at end of file

diff  --git a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
index d8cebee063a49..e813d030eec32 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
@@ -99,6 +99,21 @@ inline GFCstAndRegMatch m_GFCst(Optional<FPValueAndVReg> &FPValReg) {
   return GFCstAndRegMatch(FPValReg);
 }
 
+struct GFCstOrSplatGFCstMatch {
+  Optional<FPValueAndVReg> &FPValReg;
+  GFCstOrSplatGFCstMatch(Optional<FPValueAndVReg> &FPValReg)
+      : FPValReg(FPValReg) {}
+  bool match(const MachineRegisterInfo &MRI, Register Reg) {
+    return (FPValReg = getFConstantSplat(Reg, MRI)) ||
+           (FPValReg = getFConstantVRegValWithLookThrough(Reg, MRI));
+  };
+};
+
+inline GFCstOrSplatGFCstMatch
+m_GFCstOrSplat(Optional<FPValueAndVReg> &FPValReg) {
+  return GFCstOrSplatGFCstMatch(FPValReg);
+}
+
 /// Matcher for a specific constant value.
 struct SpecificConstantMatch {
   int64_t RequestedVal;

diff  --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
index daaa099115486..a6e6e4942d22d 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
@@ -357,15 +357,23 @@ Optional<int> getSplatIndex(MachineInstr &MI);
 Optional<int64_t> getBuildVectorConstantSplat(const MachineInstr &MI,
                                               const MachineRegisterInfo &MRI);
 
+/// Returns a floating point scalar constant of a build vector splat if it
+/// exists. When \p AllowUndef == true some elements can be undef but not all.
+Optional<FPValueAndVReg> getFConstantSplat(Register VReg,
+                                           const MachineRegisterInfo &MRI,
+                                           bool AllowUndef = true);
+
 /// Return true if the specified instruction is a G_BUILD_VECTOR or
 /// G_BUILD_VECTOR_TRUNC where all of the elements are 0 or undef.
 bool isBuildVectorAllZeros(const MachineInstr &MI,
-                           const MachineRegisterInfo &MRI);
+                           const MachineRegisterInfo &MRI,
+                           bool AllowUndef = false);
 
 /// Return true if the specified instruction is a G_BUILD_VECTOR or
 /// G_BUILD_VECTOR_TRUNC where all of the elements are ~0 or undef.
 bool isBuildVectorAllOnes(const MachineInstr &MI,
-                          const MachineRegisterInfo &MRI);
+                          const MachineRegisterInfo &MRI,
+                          bool AllowUndef = false);
 
 /// \returns a value when \p MI is a vector splat. The splat can be either a
 /// Register or a constant.

diff  --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
index 177d4025bbb8f..3c09df0b69703 100644
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -15,6 +15,7 @@
 #include "llvm/ADT/Optional.h"
 #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
 #include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
+#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
 #include "llvm/CodeGen/GlobalISel/RegisterBankInfo.h"
 #include "llvm/CodeGen/MachineInstr.h"
@@ -924,53 +925,81 @@ static bool isBuildVectorOp(unsigned Opcode) {
          Opcode == TargetOpcode::G_BUILD_VECTOR_TRUNC;
 }
 
-// TODO: Handle mixed undef elements.
-static bool isBuildVectorConstantSplat(const MachineInstr &MI,
-                                       const MachineRegisterInfo &MRI,
-                                       int64_t SplatValue) {
-  if (!isBuildVectorOp(MI.getOpcode()))
-    return false;
+namespace {
 
-  const unsigned NumOps = MI.getNumOperands();
-  for (unsigned I = 1; I != NumOps; ++I) {
-    Register Element = MI.getOperand(I).getReg();
-    if (!mi_match(Element, MRI, m_SpecificICst(SplatValue)))
-      return false;
+Optional<ValueAndVReg> getAnyConstantSplat(Register VReg,
+                                           const MachineRegisterInfo &MRI,
+                                           bool AllowUndef) {
+  MachineInstr *MI = getDefIgnoringCopies(VReg, MRI);
+  if (!MI)
+    return None;
+
+  if (!isBuildVectorOp(MI->getOpcode()))
+    return None;
+
+  Optional<ValueAndVReg> SplatValAndReg = None;
+  for (MachineOperand &Op : MI->uses()) {
+    Register Element = Op.getReg();
+    auto ElementValAndReg =
+        getAnyConstantVRegValWithLookThrough(Element, MRI, true, true);
+
+    // If AllowUndef, treat undef as value that will result in a constant splat.
+    if (!ElementValAndReg) {
+      if (AllowUndef && isa<GImplicitDef>(MRI.getVRegDef(Element)))
+        continue;
+      return None;
+    }
+
+    // Record splat value
+    if (!SplatValAndReg)
+      SplatValAndReg = ElementValAndReg;
+
+    // Different constant than the one already recorded, not a constant splat.
+    if (SplatValAndReg->Value != ElementValAndReg->Value)
+      return None;
   }
 
-  return true;
+  return SplatValAndReg;
 }
 
+bool isBuildVectorConstantSplat(const MachineInstr &MI,
+                                const MachineRegisterInfo &MRI,
+                                int64_t SplatValue, bool AllowUndef) {
+  if (auto SplatValAndReg =
+          getAnyConstantSplat(MI.getOperand(0).getReg(), MRI, AllowUndef))
+    return mi_match(SplatValAndReg->VReg, MRI, m_SpecificICst(SplatValue));
+  return false;
+}
+
+} // end anonymous namespace
+
 Optional<int64_t>
 llvm::getBuildVectorConstantSplat(const MachineInstr &MI,
                                   const MachineRegisterInfo &MRI) {
-  if (!isBuildVectorOp(MI.getOpcode()))
-    return None;
-
-  const unsigned NumOps = MI.getNumOperands();
-  Optional<int64_t> Scalar;
-  for (unsigned I = 1; I != NumOps; ++I) {
-    Register Element = MI.getOperand(I).getReg();
-    int64_t ElementValue;
-    if (!mi_match(Element, MRI, m_ICst(ElementValue)))
-      return None;
-    if (!Scalar)
-      Scalar = ElementValue;
-    else if (*Scalar != ElementValue)
-      return None;
-  }
+  if (auto SplatValAndReg =
+          getAnyConstantSplat(MI.getOperand(0).getReg(), MRI, false))
+    return getIConstantVRegSExtVal(SplatValAndReg->VReg, MRI);
+  return None;
+}
 
-  return Scalar;
+Optional<FPValueAndVReg> llvm::getFConstantSplat(Register VReg,
+                                                 const MachineRegisterInfo &MRI,
+                                                 bool AllowUndef) {
+  if (auto SplatValAndReg = getAnyConstantSplat(VReg, MRI, AllowUndef))
+    return getFConstantVRegValWithLookThrough(SplatValAndReg->VReg, MRI);
+  return None;
 }
 
 bool llvm::isBuildVectorAllZeros(const MachineInstr &MI,
-                                 const MachineRegisterInfo &MRI) {
-  return isBuildVectorConstantSplat(MI, MRI, 0);
+                                 const MachineRegisterInfo &MRI,
+                                 bool AllowUndef) {
+  return isBuildVectorConstantSplat(MI, MRI, 0, AllowUndef);
 }
 
 bool llvm::isBuildVectorAllOnes(const MachineInstr &MI,
-                                const MachineRegisterInfo &MRI) {
-  return isBuildVectorConstantSplat(MI, MRI, -1);
+                                const MachineRegisterInfo &MRI,
+                                bool AllowUndef) {
+  return isBuildVectorConstantSplat(MI, MRI, -1, AllowUndef);
 }
 
 Optional<RegOrConstant> llvm::getVectorSplat(const MachineInstr &MI,

diff  --git a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
index 9ebb4b1cc54f0..b5f4e2266b07b 100644
--- a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
@@ -574,6 +574,57 @@ TEST_F(AArch64GISelMITest, MatchFPOrIntConst) {
   EXPECT_EQ(FPOne, FValReg->VReg);
 }
 
+TEST_F(AArch64GISelMITest, MatchConstantSplat) {
+  setUp();
+  if (!TM)
+    return;
+
+  LLT s64 = LLT::scalar(64);
+  LLT v4s64 = LLT::fixed_vector(4, 64);
+
+  Register FPOne = B.buildFConstant(s64, 1.0).getReg(0);
+  Register FPZero = B.buildFConstant(s64, 0.0).getReg(0);
+  Register Undef = B.buildUndef(s64).getReg(0);
+  Optional<FPValueAndVReg> FValReg;
+
+  // GFCstOrSplatGFCstMatch allows undef elements as part of a splat. Undef
+  // often comes from padding a vector to a legal type for an available
+  // operation, with the added elements ignored afterwards, e.g. v3s64 to v4s64.
+
+  EXPECT_TRUE(mi_match(FPZero, *MRI, GFCstOrSplatGFCstMatch(FValReg)));
+  EXPECT_EQ(FPZero, FValReg->VReg);
+
+  EXPECT_FALSE(mi_match(Undef, *MRI, GFCstOrSplatGFCstMatch(FValReg)));
+
+  auto ZeroSplat = B.buildBuildVector(v4s64, {FPZero, FPZero, FPZero, FPZero});
+  EXPECT_TRUE(
+      mi_match(ZeroSplat.getReg(0), *MRI, GFCstOrSplatGFCstMatch(FValReg)));
+  EXPECT_EQ(FPZero, FValReg->VReg);
+
+  auto ZeroUndef = B.buildBuildVector(v4s64, {FPZero, FPZero, FPZero, Undef});
+  EXPECT_TRUE(
+      mi_match(ZeroUndef.getReg(0), *MRI, GFCstOrSplatGFCstMatch(FValReg)));
+  EXPECT_EQ(FPZero, FValReg->VReg);
+
+  // All undefs are not constant splat.
+  auto UndefSplat = B.buildBuildVector(v4s64, {Undef, Undef, Undef, Undef});
+  EXPECT_FALSE(
+      mi_match(UndefSplat.getReg(0), *MRI, GFCstOrSplatGFCstMatch(FValReg)));
+
+  auto ZeroOne = B.buildBuildVector(v4s64, {FPZero, FPZero, FPZero, FPOne});
+  EXPECT_FALSE(
+      mi_match(ZeroOne.getReg(0), *MRI, GFCstOrSplatGFCstMatch(FValReg)));
+
+  auto NonConstantSplat =
+      B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]});
+  EXPECT_FALSE(mi_match(NonConstantSplat.getReg(0), *MRI,
+                        GFCstOrSplatGFCstMatch(FValReg)));
+
+  auto Mixed = B.buildBuildVector(v4s64, {FPZero, FPZero, FPZero, Copies[0]});
+  EXPECT_FALSE(
+      mi_match(Mixed.getReg(0), *MRI, GFCstOrSplatGFCstMatch(FValReg)));
+}
+
 TEST_F(AArch64GISelMITest, MatchNeg) {
   setUp();
   if (!TM)


        


More information about the llvm-commits mailing list