[llvm-branch-commits] [llvm] de373ef - [SelectionDAG] Extend immAll(Ones|Zeros)V to handle ISD::SPLAT_VECTOR

Fraser Cormack via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sat Jan 9 09:16:24 PST 2021


Author: Fraser Cormack
Date: 2021-01-09T17:05:31Z
New Revision: de373ef779880e923636d90cdb277e4db84c7479

URL: https://github.com/llvm/llvm-project/commit/de373ef779880e923636d90cdb277e4db84c7479
DIFF: https://github.com/llvm/llvm-project/commit/de373ef779880e923636d90cdb277e4db84c7479.diff

LOG: [SelectionDAG] Extend immAll(Ones|Zeros)V to handle ISD::SPLAT_VECTOR

The TableGen immAllOnesV and immAllZerosV helpers implicitly wrapped the
ISD::isBuildVectorAll(Ones|Zeros) helper functions. This was inhibiting
their use for targets such as RISC-V which use ISD::SPLAT_VECTOR. In
particular, RISC-V had to define its own 'vnot' fragment.

In order to extend the scope of these nodes to include support for
ISD::SPLAT_VECTOR, two new ISD predicate functions have been introduced:
ISD::isConstantSplatVectorAll(Ones|Zeros). These effectively supersede
the older "isBuildVector" predicates, which are now simple wrappers for
the new functions. The wrappers pass 'true' for the new BuildVectorOnly
parameter (which defaults to 'false'), preserving the old behaviour. It is
hoped that in time all call sites can be ported to the
"isConstantSplatVector" functions.

While the use of ISD::isBuildVectorAll(Ones|Zeros) has not changed, the
behaviour of the TableGen immAll(Ones|Zeros)V **has**. To test the new
functionality, the custom RISC-V 'riscv_m_vnot' TableGen fragment has been
removed and replaced with the built-in 'vnot'. To test the use of
immAll(Ones|Zeros)V as pattern roots, the two RISC-V mask splat patterns
have been updated to use them.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D94223

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/SelectionDAGNodes.h
    llvm/include/llvm/Target/TargetSelectionDAG.td
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
    llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
    llvm/utils/TableGen/DAGISelMatcher.h
    llvm/utils/TableGen/DAGISelMatcherGen.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 5926c52c51d9..3d122402bf57 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -85,29 +85,42 @@ namespace ISD {
 
   /// Node predicates
 
-  /// If N is a BUILD_VECTOR node whose elements are all the same constant or
-  /// undefined, return true and return the constant value in \p SplatValue.
-  bool isConstantSplatVector(const SDNode *N, APInt &SplatValue);
-
-  /// Return true if the specified node is a BUILD_VECTOR where all of the
-  /// elements are ~0 or undef.
-  bool isBuildVectorAllOnes(const SDNode *N);
-
-  /// Return true if the specified node is a BUILD_VECTOR where all of the
-  /// elements are 0 or undef.
-  bool isBuildVectorAllZeros(const SDNode *N);
-
-  /// Return true if the specified node is a BUILD_VECTOR node of all
-  /// ConstantSDNode or undef.
-  bool isBuildVectorOfConstantSDNodes(const SDNode *N);
-
-  /// Return true if the specified node is a BUILD_VECTOR node of all
-  /// ConstantFPSDNode or undef.
-  bool isBuildVectorOfConstantFPSDNodes(const SDNode *N);
-
-  /// Return true if the node has at least one operand and all operands of the
-  /// specified node are ISD::UNDEF.
-  bool allOperandsUndef(const SDNode *N);
+/// If N is a BUILD_VECTOR or SPLAT_VECTOR node whose elements are all the
+/// same constant or undefined, return true and return the constant value in
+/// \p SplatValue.
+bool isConstantSplatVector(const SDNode *N, APInt &SplatValue);
+
+/// Return true if the specified node is a BUILD_VECTOR or SPLAT_VECTOR where
+/// all of the elements are ~0 or undef. If \p BuildVectorOnly is set to
+/// true, it only checks BUILD_VECTOR.
+bool isConstantSplatVectorAllOnes(const SDNode *N,
+                                  bool BuildVectorOnly = false);
+
+/// Return true if the specified node is a BUILD_VECTOR or SPLAT_VECTOR where
+/// all of the elements are 0 or undef. If \p BuildVectorOnly is set to true, it
+/// only checks BUILD_VECTOR.
+bool isConstantSplatVectorAllZeros(const SDNode *N,
+                                   bool BuildVectorOnly = false);
+
+/// Return true if the specified node is a BUILD_VECTOR where all of the
+/// elements are ~0 or undef.
+bool isBuildVectorAllOnes(const SDNode *N);
+
+/// Return true if the specified node is a BUILD_VECTOR where all of the
+/// elements are 0 or undef.
+bool isBuildVectorAllZeros(const SDNode *N);
+
+/// Return true if the specified node is a BUILD_VECTOR node of all
+/// ConstantSDNode or undef.
+bool isBuildVectorOfConstantSDNodes(const SDNode *N);
+
+/// Return true if the specified node is a BUILD_VECTOR node of all
+/// ConstantFPSDNode or undef.
+bool isBuildVectorOfConstantFPSDNodes(const SDNode *N);
+
+/// Return true if the node has at least one operand and all operands of the
+/// specified node are ISD::UNDEF.
+bool allOperandsUndef(const SDNode *N);
 
 } // end namespace ISD
 

diff  --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index a1e961aa9cb5..a09feca6ca9b 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -909,11 +909,13 @@ class FPImmLeaf<ValueType vt, code pred, SDNodeXForm xform = NOOP_SDNodeXForm>
 def vtInt      : PatLeaf<(vt),  [{ return N->getVT().isInteger(); }]>;
 def vtFP       : PatLeaf<(vt),  [{ return N->getVT().isFloatingPoint(); }]>;
 
-// Use ISD::isBuildVectorAllOnes or ISD::isBuildVectorAllZeros to look for
-// the corresponding build_vector. Will look through bitcasts except when used
-// as a pattern root.
-def immAllOnesV; // ISD::isBuildVectorAllOnes
-def immAllZerosV; // ISD::isBuildVectorAllZeros
+// Use ISD::isConstantSplatVectorAllOnes or ISD::isConstantSplatVectorAllZeros
+// to look for the corresponding build_vector or splat_vector. Will look through
+// bitcasts and check for either opcode, except when used as a pattern root.
+// When used as a pattern root, only fixed-length build_vector and scalable
+// splat_vector are supported.
+def immAllOnesV; // ISD::isConstantSplatVectorAllOnes
+def immAllZerosV; // ISD::isConstantSplatVectorAllZeros
 
 // Other helper fragments.
 def not  : PatFrag<(ops node:$in), (xor node:$in, -1)>;

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index b13cb4f019a8..14496b57378d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -164,11 +164,16 @@ bool ISD::isConstantSplatVector(const SDNode *N, APInt &SplatVal) {
 // FIXME: AllOnes and AllZeros duplicate a lot of code. Could these be
 // specializations of the more general isConstantSplatVector()?
 
-bool ISD::isBuildVectorAllOnes(const SDNode *N) {
+bool ISD::isConstantSplatVectorAllOnes(const SDNode *N, bool BuildVectorOnly) {
   // Look through a bit convert.
   while (N->getOpcode() == ISD::BITCAST)
     N = N->getOperand(0).getNode();
 
+  if (!BuildVectorOnly && N->getOpcode() == ISD::SPLAT_VECTOR) {
+    APInt SplatVal;
+    return isConstantSplatVector(N, SplatVal) && SplatVal.isAllOnesValue();
+  }
+
   if (N->getOpcode() != ISD::BUILD_VECTOR) return false;
 
   unsigned i = 0, e = N->getNumOperands();
@@ -208,11 +213,16 @@ bool ISD::isBuildVectorAllOnes(const SDNode *N) {
   return true;
 }
 
-bool ISD::isBuildVectorAllZeros(const SDNode *N) {
+bool ISD::isConstantSplatVectorAllZeros(const SDNode *N, bool BuildVectorOnly) {
   // Look through a bit convert.
   while (N->getOpcode() == ISD::BITCAST)
     N = N->getOperand(0).getNode();
 
+  if (!BuildVectorOnly && N->getOpcode() == ISD::SPLAT_VECTOR) {
+    APInt SplatVal;
+    return isConstantSplatVector(N, SplatVal) && SplatVal.isNullValue();
+  }
+
   if (N->getOpcode() != ISD::BUILD_VECTOR) return false;
 
   bool IsAllUndef = true;
@@ -245,6 +255,14 @@ bool ISD::isBuildVectorAllZeros(const SDNode *N) {
   return true;
 }
 
+bool ISD::isBuildVectorAllOnes(const SDNode *N) {
+  return isConstantSplatVectorAllOnes(N, /*BuildVectorOnly*/ true);
+}
+
+bool ISD::isBuildVectorAllZeros(const SDNode *N) {
+  return isConstantSplatVectorAllZeros(N, /*BuildVectorOnly*/ true);
+}
+
 bool ISD::isBuildVectorOfConstantSDNodes(const SDNode *N) {
   if (N->getOpcode() != ISD::BUILD_VECTOR)
     return false;

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
index c26dbfa6567c..7bae5048fc0e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
@@ -3202,10 +3202,12 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
       if (!::CheckOrImm(MatcherTable, MatcherIndex, N, *this)) break;
       continue;
     case OPC_CheckImmAllOnesV:
-      if (!ISD::isBuildVectorAllOnes(N.getNode())) break;
+      if (!ISD::isConstantSplatVectorAllOnes(N.getNode()))
+        break;
       continue;
     case OPC_CheckImmAllZerosV:
-      if (!ISD::isBuildVectorAllZeros(N.getNode())) break;
+      if (!ISD::isConstantSplatVectorAllZeros(N.getNode()))
+        break;
       continue;
 
     case OPC_CheckFoldableChainNode: {

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index 208e50168897..e158b632aa73 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -35,11 +35,6 @@ def SplatPat       : ComplexPattern<vAny, 1, "selectVSplat", [], [], 1>;
 def SplatPat_simm5 : ComplexPattern<vAny, 1, "selectVSplatSimm5", []>;
 def SplatPat_uimm5 : ComplexPattern<vAny, 1, "selectVSplatUimm5", []>;
 
-// A mask-vector version of the standard 'vnot' fragment but using splat_vector
-// rather than (the implicit) build_vector
-def riscv_m_vnot : PatFrag<(ops node:$in),
-                           (xor node:$in, (splat_vector (XLenVT 1)))>;
-
 multiclass VPatUSLoadStoreSDNode<LLVMType type,
                                  LLVMType mask_type,
                                  int sew,
@@ -198,20 +193,20 @@ foreach mti = AllMasks in {
             (!cast<Instruction>("PseudoVMXOR_MM_"#mti.LMul.MX)
                  VR:$rs1, VR:$rs2, VLMax, mti.SEW)>;
 
-  def : Pat<(mti.Mask (riscv_m_vnot (and VR:$rs1, VR:$rs2))),
+  def : Pat<(mti.Mask (vnot (and VR:$rs1, VR:$rs2))),
             (!cast<Instruction>("PseudoVMNAND_MM_"#mti.LMul.MX)
                  VR:$rs1, VR:$rs2, VLMax, mti.SEW)>;
-  def : Pat<(mti.Mask (riscv_m_vnot (or VR:$rs1, VR:$rs2))),
+  def : Pat<(mti.Mask (vnot (or VR:$rs1, VR:$rs2))),
             (!cast<Instruction>("PseudoVMNOR_MM_"#mti.LMul.MX)
                  VR:$rs1, VR:$rs2, VLMax, mti.SEW)>;
-  def : Pat<(mti.Mask (riscv_m_vnot (xor VR:$rs1, VR:$rs2))),
+  def : Pat<(mti.Mask (vnot (xor VR:$rs1, VR:$rs2))),
             (!cast<Instruction>("PseudoVMXNOR_MM_"#mti.LMul.MX)
                  VR:$rs1, VR:$rs2, VLMax, mti.SEW)>;
 
-  def : Pat<(mti.Mask (and VR:$rs1, (riscv_m_vnot VR:$rs2))),
+  def : Pat<(mti.Mask (and VR:$rs1, (vnot VR:$rs2))),
             (!cast<Instruction>("PseudoVMANDNOT_MM_"#mti.LMul.MX)
                  VR:$rs1, VR:$rs2, VLMax, mti.SEW)>;
-  def : Pat<(mti.Mask (or VR:$rs1, (riscv_m_vnot VR:$rs2))),
+  def : Pat<(mti.Mask (or VR:$rs1, (vnot VR:$rs2))),
             (!cast<Instruction>("PseudoVMORNOT_MM_"#mti.LMul.MX)
                  VR:$rs1, VR:$rs2, VLMax, mti.SEW)>;
 }
@@ -233,9 +228,9 @@ foreach vti = AllIntegerVectors in {
 }
 
 foreach mti = AllMasks in {
-  def : Pat<(mti.Mask (splat_vector (XLenVT 1))),
+  def : Pat<(mti.Mask immAllOnesV),
             (!cast<Instruction>("PseudoVMSET_M_"#mti.BX) VLMax, mti.SEW)>;
-  def : Pat<(mti.Mask (splat_vector (XLenVT 0))),
+  def : Pat<(mti.Mask immAllZerosV),
             (!cast<Instruction>("PseudoVMCLR_M_"#mti.BX) VLMax, mti.SEW)>;
 }
 } // Predicates = [HasStdExtV]

diff  --git a/llvm/utils/TableGen/DAGISelMatcher.h b/llvm/utils/TableGen/DAGISelMatcher.h
index dca1865b22e0..3a920d1d7f22 100644
--- a/llvm/utils/TableGen/DAGISelMatcher.h
+++ b/llvm/utils/TableGen/DAGISelMatcher.h
@@ -763,8 +763,8 @@ class CheckOrImmMatcher : public Matcher {
   }
 };
 
-/// CheckImmAllOnesVMatcher - This check if the current node is an build vector
-/// of all ones.
+/// CheckImmAllOnesVMatcher - This checks if the current node is a build_vector
+/// or splat_vector of all ones.
 class CheckImmAllOnesVMatcher : public Matcher {
 public:
   CheckImmAllOnesVMatcher() : Matcher(CheckImmAllOnesV) {}
@@ -779,8 +779,8 @@ class CheckImmAllOnesVMatcher : public Matcher {
   bool isContradictoryImpl(const Matcher *M) const override;
 };
 
-/// CheckImmAllZerosVMatcher - This check if the current node is an build vector
-/// of all zeros.
+/// CheckImmAllZerosVMatcher - This checks if the current node is a
+/// build_vector or splat_vector of all zeros.
 class CheckImmAllZerosVMatcher : public Matcher {
 public:
   CheckImmAllZerosVMatcher() : Matcher(CheckImmAllZerosV) {}

diff  --git a/llvm/utils/TableGen/DAGISelMatcherGen.cpp b/llvm/utils/TableGen/DAGISelMatcherGen.cpp
index 792bf17690c1..f7415b87e1c0 100644
--- a/llvm/utils/TableGen/DAGISelMatcherGen.cpp
+++ b/llvm/utils/TableGen/DAGISelMatcherGen.cpp
@@ -282,7 +282,9 @@ void MatcherGen::EmitLeafMatchCode(const TreePatternNode *N) {
     // check to ensure that this gets folded into the normal top-level
     // OpcodeSwitch.
     if (N == Pattern.getSrcPattern()) {
-      const SDNodeInfo &NI = CGP.getSDNodeInfo(CGP.getSDNodeNamed("build_vector"));
+      MVT VT = N->getSimpleType(0);
+      StringRef Name = VT.isScalableVector() ? "splat_vector" : "build_vector";
+      const SDNodeInfo &NI = CGP.getSDNodeInfo(CGP.getSDNodeNamed(Name));
       AddMatcher(new CheckOpcodeMatcher(NI));
     }
     return AddMatcher(new CheckImmAllOnesVMatcher());
@@ -292,7 +294,9 @@ void MatcherGen::EmitLeafMatchCode(const TreePatternNode *N) {
     // check to ensure that this gets folded into the normal top-level
     // OpcodeSwitch.
     if (N == Pattern.getSrcPattern()) {
-      const SDNodeInfo &NI = CGP.getSDNodeInfo(CGP.getSDNodeNamed("build_vector"));
+      MVT VT = N->getSimpleType(0);
+      StringRef Name = VT.isScalableVector() ? "splat_vector" : "build_vector";
+      const SDNodeInfo &NI = CGP.getSDNodeInfo(CGP.getSDNodeNamed(Name));
       AddMatcher(new CheckOpcodeMatcher(NI));
     }
     return AddMatcher(new CheckImmAllZerosVMatcher());


        


More information about the llvm-branch-commits mailing list