[llvm] [RISCV] Move vector pseudo hasAllNBitUsers switch into RISCVBaseInfo.{h,cpp}. NFC (PR #67593)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 27 13:09:42 PDT 2023


https://github.com/lukel97 updated https://github.com/llvm/llvm-project/pull/67593

>From e71e7fbab3099cd37f4ff64be7b16d4d31dae7c7 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Wed, 27 Sep 2023 12:20:54 +0100
Subject: [PATCH 1/2] [RISCV] Move vector pseudo hasAllNBitUsers switch into
 RISCVBaseInfo.{h,cpp}. NFC

The handling for vector pseudos in hasAllNBitUsers is duplicated across
RISCVISelDAGToDAG and RISCVOptWInstrs. This patch deduplicates it by moving the
shared opcode switch into a common helper keyed on the opcode and SEW, which
are the common denominator between the two call sites. Each caller still
extracts these itself, since one operates at the SelectionDAG level and the
other at the MachineInstr level.
---
 .../RISCV/MCTargetDesc/RISCVBaseInfo.cpp      | 112 ++++++++++++
 .../Target/RISCV/MCTargetDesc/RISCVBaseInfo.h |   5 +
 llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp   | 170 +++---------------
 llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp     | 154 ++--------------
 4 files changed, 158 insertions(+), 283 deletions(-)

diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.cpp b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.cpp
index 0a42c6faee29008..95cea0c61acfd5d 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.cpp
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.cpp
@@ -130,6 +130,118 @@ parseFeatureBits(bool IsRV64, const FeatureBitset &FeatureBits) {
 
 } // namespace RISCVFeatures
 
+bool RISCVII::vectorInstUsesNBitsOfScalarOp(uint16_t Opcode, unsigned Bits,
+                                            unsigned Log2SEW) {
+  // TODO: Handle Zvbb instructions
+  switch (Opcode) {
+  default:
+    return false;
+
+  // 11.6. Vector Single-Width Shift Instructions
+  case RISCV::VSLL_VX:
+  case RISCV::VSRL_VX:
+  case RISCV::VSRA_VX:
+  // 12.4. Vector Single-Width Scaling Shift Instructions
+  case RISCV::VSSRL_VX:
+  case RISCV::VSSRA_VX:
+    // Only the low lg2(SEW) bits of the shift-amount value are used.
+    return Log2SEW <= Bits;
+
+  // 11.7 Vector Narrowing Integer Right Shift Instructions
+  case RISCV::VNSRL_WX:
+  case RISCV::VNSRA_WX:
+  // 12.5. Vector Narrowing Fixed-Point Clip Instructions
+  case RISCV::VNCLIPU_WX:
+  case RISCV::VNCLIP_WX:
+    // Only the low lg2(2*SEW) bits of the shift-amount value are used.
+    return (Log2SEW + 1) <= Bits;
+
+  // 11.1. Vector Single-Width Integer Add and Subtract
+  case RISCV::VADD_VX:
+  case RISCV::VSUB_VX:
+  case RISCV::VRSUB_VX:
+  // 11.2. Vector Widening Integer Add/Subtract
+  case RISCV::VWADDU_VX:
+  case RISCV::VWSUBU_VX:
+  case RISCV::VWADD_VX:
+  case RISCV::VWSUB_VX:
+  case RISCV::VWADDU_WX:
+  case RISCV::VWSUBU_WX:
+  case RISCV::VWADD_WX:
+  case RISCV::VWSUB_WX:
+  // 11.4. Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
+  case RISCV::VADC_VXM:
+  case RISCV::VADC_VIM:
+  case RISCV::VMADC_VXM:
+  case RISCV::VMADC_VIM:
+  case RISCV::VMADC_VX:
+  case RISCV::VSBC_VXM:
+  case RISCV::VMSBC_VXM:
+  case RISCV::VMSBC_VX:
+  // 11.5 Vector Bitwise Logical Instructions
+  case RISCV::VAND_VX:
+  case RISCV::VOR_VX:
+  case RISCV::VXOR_VX:
+  // 11.8. Vector Integer Compare Instructions
+  case RISCV::VMSEQ_VX:
+  case RISCV::VMSNE_VX:
+  case RISCV::VMSLTU_VX:
+  case RISCV::VMSLT_VX:
+  case RISCV::VMSLEU_VX:
+  case RISCV::VMSLE_VX:
+  case RISCV::VMSGTU_VX:
+  case RISCV::VMSGT_VX:
+  // 11.9. Vector Integer Min/Max Instructions
+  case RISCV::VMINU_VX:
+  case RISCV::VMIN_VX:
+  case RISCV::VMAXU_VX:
+  case RISCV::VMAX_VX:
+  // 11.10. Vector Single-Width Integer Multiply Instructions
+  case RISCV::VMUL_VX:
+  case RISCV::VMULH_VX:
+  case RISCV::VMULHU_VX:
+  case RISCV::VMULHSU_VX:
+  // 11.11. Vector Integer Divide Instructions
+  case RISCV::VDIVU_VX:
+  case RISCV::VDIV_VX:
+  case RISCV::VREMU_VX:
+  case RISCV::VREM_VX:
+  // 11.12. Vector Widening Integer Multiply Instructions
+  case RISCV::VWMUL_VX:
+  case RISCV::VWMULU_VX:
+  case RISCV::VWMULSU_VX:
+  // 11.13. Vector Single-Width Integer Multiply-Add Instructions
+  case RISCV::VMACC_VX:
+  case RISCV::VNMSAC_VX:
+  case RISCV::VMADD_VX:
+  case RISCV::VNMSUB_VX:
+  // 11.14. Vector Widening Integer Multiply-Add Instructions
+  case RISCV::VWMACCU_VX:
+  case RISCV::VWMACC_VX:
+  case RISCV::VWMACCSU_VX:
+  case RISCV::VWMACCUS_VX:
+  // 11.15. Vector Integer Merge Instructions
+  case RISCV::VMERGE_VXM:
+  // 11.16. Vector Integer Move Instructions
+  case RISCV::VMV_V_X:
+  // 12.1. Vector Single-Width Saturating Add and Subtract
+  case RISCV::VSADDU_VX:
+  case RISCV::VSADD_VX:
+  case RISCV::VSSUBU_VX:
+  case RISCV::VSSUB_VX:
+  // 12.2. Vector Single-Width Averaging Add and Subtract
+  case RISCV::VAADDU_VX:
+  case RISCV::VAADD_VX:
+  case RISCV::VASUBU_VX:
+  case RISCV::VASUB_VX:
+  // 12.3. Vector Single-Width Fractional Multiply with Rounding and Saturation
+  case RISCV::VSMUL_VX:
+  // 16.1. Integer Scalar Move Instructions
+  case RISCV::VMV_S_X:
+    return (1 << Log2SEW) <= Bits;
+  }
+}
+
 // Encode VTYPE into the binary format used by the the VSETVLI instruction which
 // is used by our MC layer representation.
 //
diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
index 20ff26a39dc3b30..222d4e9eef674ec 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
@@ -241,6 +241,11 @@ static inline bool isFirstDefTiedToFirstUse(const MCInstrDesc &Desc) {
          Desc.getOperandConstraint(Desc.getNumDefs(), MCOI::TIED_TO) == 0;
 }
 
+// Returns true if the .vx vector instruction \p Opcode only uses the lower \p
+// Bits for a given SEW.
+bool vectorInstUsesNBitsOfScalarOp(uint16_t Opcode, unsigned Bits,
+                                   unsigned Log2SEW);
+
 // RISC-V Specific Machine Operand Flags
 enum {
   MO_None = 0,
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index 283ab1feda7eca5..b5f91c6bf70b061 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -2753,148 +2753,6 @@ bool RISCVDAGToDAGISel::selectSHXADD_UWOp(SDValue N, unsigned ShAmt,
   return false;
 }
 
-static bool vectorPseudoHasAllNBitUsers(SDNode *User, unsigned UserOpNo,
-                                        unsigned Bits,
-                                        const TargetInstrInfo *TII) {
-  const RISCVVPseudosTable::PseudoInfo *PseudoInfo =
-      RISCVVPseudosTable::getPseudoInfo(User->getMachineOpcode());
-
-  if (!PseudoInfo)
-    return false;
-
-  const MCInstrDesc &MCID = TII->get(User->getMachineOpcode());
-  const uint64_t TSFlags = MCID.TSFlags;
-  if (!RISCVII::hasSEWOp(TSFlags))
-    return false;
-  assert(RISCVII::hasVLOp(TSFlags));
-
-  bool HasGlueOp = User->getGluedNode() != nullptr;
-  unsigned ChainOpIdx = User->getNumOperands() - HasGlueOp - 1;
-  bool HasChainOp = User->getOperand(ChainOpIdx).getValueType() == MVT::Other;
-  bool HasVecPolicyOp = RISCVII::hasVecPolicyOp(TSFlags);
-  unsigned VLIdx =
-      User->getNumOperands() - HasVecPolicyOp - HasChainOp - HasGlueOp - 2;
-  const unsigned Log2SEW = User->getConstantOperandVal(VLIdx + 1);
-
-  if (UserOpNo == VLIdx)
-    return false;
-
-  // TODO: Handle Zvbb instructions
-  switch (PseudoInfo->BaseInstr) {
-  default:
-    return false;
-
-  // 11.6. Vector Single-Width Shift Instructions
-  case RISCV::VSLL_VX:
-  case RISCV::VSRL_VX:
-  case RISCV::VSRA_VX:
-  // 12.4. Vector Single-Width Scaling Shift Instructions
-  case RISCV::VSSRL_VX:
-  case RISCV::VSSRA_VX:
-    // Only the low lg2(SEW) bits of the shift-amount value are used.
-    if (Bits < Log2SEW)
-      return false;
-    break;
-
-  // 11.7 Vector Narrowing Integer Right Shift Instructions
-  case RISCV::VNSRL_WX:
-  case RISCV::VNSRA_WX:
-  // 12.5. Vector Narrowing Fixed-Point Clip Instructions
-  case RISCV::VNCLIPU_WX:
-  case RISCV::VNCLIP_WX:
-    // Only the low lg2(2*SEW) bits of the shift-amount value are used.
-    if (Bits < Log2SEW + 1)
-      return false;
-    break;
-
-  // 11.1. Vector Single-Width Integer Add and Subtract
-  case RISCV::VADD_VX:
-  case RISCV::VSUB_VX:
-  case RISCV::VRSUB_VX:
-  // 11.2. Vector Widening Integer Add/Subtract
-  case RISCV::VWADDU_VX:
-  case RISCV::VWSUBU_VX:
-  case RISCV::VWADD_VX:
-  case RISCV::VWSUB_VX:
-  case RISCV::VWADDU_WX:
-  case RISCV::VWSUBU_WX:
-  case RISCV::VWADD_WX:
-  case RISCV::VWSUB_WX:
-  // 11.4. Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
-  case RISCV::VADC_VXM:
-  case RISCV::VADC_VIM:
-  case RISCV::VMADC_VXM:
-  case RISCV::VMADC_VIM:
-  case RISCV::VMADC_VX:
-  case RISCV::VSBC_VXM:
-  case RISCV::VMSBC_VXM:
-  case RISCV::VMSBC_VX:
-  // 11.5 Vector Bitwise Logical Instructions
-  case RISCV::VAND_VX:
-  case RISCV::VOR_VX:
-  case RISCV::VXOR_VX:
-  // 11.8. Vector Integer Compare Instructions
-  case RISCV::VMSEQ_VX:
-  case RISCV::VMSNE_VX:
-  case RISCV::VMSLTU_VX:
-  case RISCV::VMSLT_VX:
-  case RISCV::VMSLEU_VX:
-  case RISCV::VMSLE_VX:
-  case RISCV::VMSGTU_VX:
-  case RISCV::VMSGT_VX:
-  // 11.9. Vector Integer Min/Max Instructions
-  case RISCV::VMINU_VX:
-  case RISCV::VMIN_VX:
-  case RISCV::VMAXU_VX:
-  case RISCV::VMAX_VX:
-  // 11.10. Vector Single-Width Integer Multiply Instructions
-  case RISCV::VMUL_VX:
-  case RISCV::VMULH_VX:
-  case RISCV::VMULHU_VX:
-  case RISCV::VMULHSU_VX:
-  // 11.11. Vector Integer Divide Instructions
-  case RISCV::VDIVU_VX:
-  case RISCV::VDIV_VX:
-  case RISCV::VREMU_VX:
-  case RISCV::VREM_VX:
-  // 11.12. Vector Widening Integer Multiply Instructions
-  case RISCV::VWMUL_VX:
-  case RISCV::VWMULU_VX:
-  case RISCV::VWMULSU_VX:
-  // 11.13. Vector Single-Width Integer Multiply-Add Instructions
-  case RISCV::VMACC_VX:
-  case RISCV::VNMSAC_VX:
-  case RISCV::VMADD_VX:
-  case RISCV::VNMSUB_VX:
-  // 11.14. Vector Widening Integer Multiply-Add Instructions
-  case RISCV::VWMACCU_VX:
-  case RISCV::VWMACC_VX:
-  case RISCV::VWMACCSU_VX:
-  case RISCV::VWMACCUS_VX:
-  // 11.15. Vector Integer Merge Instructions
-  case RISCV::VMERGE_VXM:
-  // 11.16. Vector Integer Move Instructions
-  case RISCV::VMV_V_X:
-  // 12.1. Vector Single-Width Saturating Add and Subtract
-  case RISCV::VSADDU_VX:
-  case RISCV::VSADD_VX:
-  case RISCV::VSSUBU_VX:
-  case RISCV::VSSUB_VX:
-  // 12.2. Vector Single-Width Averaging Add and Subtract
-  case RISCV::VAADDU_VX:
-  case RISCV::VAADD_VX:
-  case RISCV::VASUBU_VX:
-  case RISCV::VASUB_VX:
-  // 12.3. Vector Single-Width Fractional Multiply with Rounding and Saturation
-  case RISCV::VSMUL_VX:
-  // 16.1. Integer Scalar Move Instructions
-  case RISCV::VMV_S_X:
-    if (Bits < (1 << Log2SEW))
-      return false;
-  }
-  return true;
-}
-
 // Return true if all users of this SDNode* only consume the lower \p Bits.
 // This can be used to form W instructions for add/sub/mul/shl even when the
 // root isn't a sext_inreg. This can allow the ADDW/SUBW/MULW/SLLIW to CSE if
@@ -2925,10 +2783,32 @@ bool RISCVDAGToDAGISel::hasAllNBitUsers(SDNode *Node, unsigned Bits,
 
     // TODO: Add more opcodes?
     switch (User->getMachineOpcode()) {
-    default:
-      if (vectorPseudoHasAllNBitUsers(User, UI.getOperandNo(), Bits, TII))
-        break;
+    default: {
+      if (const RISCVVPseudosTable::PseudoInfo *PseudoInfo =
+              RISCVVPseudosTable::getPseudoInfo(User->getMachineOpcode())) {
+
+        const MCInstrDesc &MCID = TII->get(User->getMachineOpcode());
+        if (!RISCVII::hasSEWOp(MCID.TSFlags))
+          return false;
+        assert(RISCVII::hasVLOp(MCID.TSFlags));
+
+        bool HasGlueOp = User->getGluedNode() != nullptr;
+        unsigned ChainOpIdx = User->getNumOperands() - HasGlueOp - 1;
+        bool HasChainOp =
+            User->getOperand(ChainOpIdx).getValueType() == MVT::Other;
+        bool HasVecPolicyOp = RISCVII::hasVecPolicyOp(MCID.TSFlags);
+        unsigned VLIdx = User->getNumOperands() - HasVecPolicyOp - HasChainOp -
+                         HasGlueOp - 2;
+        const unsigned Log2SEW = User->getConstantOperandVal(VLIdx + 1);
+
+        if (UI.getOperandNo() == VLIdx)
+          return false;
+        if (RISCVII::vectorInstUsesNBitsOfScalarOp(PseudoInfo->BaseInstr, Bits,
+                                                   Log2SEW))
+          break;
+      }
       return false;
+    }
     case RISCV::ADDW:
     case RISCV::ADDIW:
     case RISCV::SUBW:
diff --git a/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp b/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp
index 0cbdfa84640bf91..56aa5589bc19a29 100644
--- a/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp
+++ b/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp
@@ -78,141 +78,6 @@ FunctionPass *llvm::createRISCVOptWInstrsPass() {
   return new RISCVOptWInstrs();
 }
 
-static bool vectorPseudoHasAllNBitUsers(const MachineOperand &UserOp,
-                                        unsigned Bits) {
-  const MachineInstr &MI = *UserOp.getParent();
-  const RISCVVPseudosTable::PseudoInfo *PseudoInfo =
-      RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
-
-  if (!PseudoInfo)
-    return false;
-
-  const MCInstrDesc &MCID = MI.getDesc();
-  const uint64_t TSFlags = MI.getDesc().TSFlags;
-  if (!RISCVII::hasSEWOp(TSFlags))
-    return false;
-  assert(RISCVII::hasVLOp(TSFlags));
-  const unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MCID)).getImm();
-
-  if (UserOp.getOperandNo() == RISCVII::getVLOpNum(MCID))
-    return false;
-
-  // TODO: Handle Zvbb instructions
-  switch (PseudoInfo->BaseInstr) {
-  default:
-    return false;
-
-  // 11.6. Vector Single-Width Shift Instructions
-  case RISCV::VSLL_VX:
-  case RISCV::VSRL_VX:
-  case RISCV::VSRA_VX:
-  // 12.4. Vector Single-Width Scaling Shift Instructions
-  case RISCV::VSSRL_VX:
-  case RISCV::VSSRA_VX:
-    // Only the low lg2(SEW) bits of the shift-amount value are used.
-    if (Bits < Log2SEW)
-      return false;
-    break;
-
-  // 11.7 Vector Narrowing Integer Right Shift Instructions
-  case RISCV::VNSRL_WX:
-  case RISCV::VNSRA_WX:
-  // 12.5. Vector Narrowing Fixed-Point Clip Instructions
-  case RISCV::VNCLIPU_WX:
-  case RISCV::VNCLIP_WX:
-    // Only the low lg2(2*SEW) bits of the shift-amount value are used.
-    if (Bits < Log2SEW + 1)
-      return false;
-    break;
-
-  // 11.1. Vector Single-Width Integer Add and Subtract
-  case RISCV::VADD_VX:
-  case RISCV::VSUB_VX:
-  case RISCV::VRSUB_VX:
-  // 11.2. Vector Widening Integer Add/Subtract
-  case RISCV::VWADDU_VX:
-  case RISCV::VWSUBU_VX:
-  case RISCV::VWADD_VX:
-  case RISCV::VWSUB_VX:
-  case RISCV::VWADDU_WX:
-  case RISCV::VWSUBU_WX:
-  case RISCV::VWADD_WX:
-  case RISCV::VWSUB_WX:
-  // 11.4. Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
-  case RISCV::VADC_VXM:
-  case RISCV::VADC_VIM:
-  case RISCV::VMADC_VXM:
-  case RISCV::VMADC_VIM:
-  case RISCV::VMADC_VX:
-  case RISCV::VSBC_VXM:
-  case RISCV::VMSBC_VXM:
-  case RISCV::VMSBC_VX:
-  // 11.5 Vector Bitwise Logical Instructions
-  case RISCV::VAND_VX:
-  case RISCV::VOR_VX:
-  case RISCV::VXOR_VX:
-  // 11.8. Vector Integer Compare Instructions
-  case RISCV::VMSEQ_VX:
-  case RISCV::VMSNE_VX:
-  case RISCV::VMSLTU_VX:
-  case RISCV::VMSLT_VX:
-  case RISCV::VMSLEU_VX:
-  case RISCV::VMSLE_VX:
-  case RISCV::VMSGTU_VX:
-  case RISCV::VMSGT_VX:
-  // 11.9. Vector Integer Min/Max Instructions
-  case RISCV::VMINU_VX:
-  case RISCV::VMIN_VX:
-  case RISCV::VMAXU_VX:
-  case RISCV::VMAX_VX:
-  // 11.10. Vector Single-Width Integer Multiply Instructions
-  case RISCV::VMUL_VX:
-  case RISCV::VMULH_VX:
-  case RISCV::VMULHU_VX:
-  case RISCV::VMULHSU_VX:
-  // 11.11. Vector Integer Divide Instructions
-  case RISCV::VDIVU_VX:
-  case RISCV::VDIV_VX:
-  case RISCV::VREMU_VX:
-  case RISCV::VREM_VX:
-  // 11.12. Vector Widening Integer Multiply Instructions
-  case RISCV::VWMUL_VX:
-  case RISCV::VWMULU_VX:
-  case RISCV::VWMULSU_VX:
-  // 11.13. Vector Single-Width Integer Multiply-Add Instructions
-  case RISCV::VMACC_VX:
-  case RISCV::VNMSAC_VX:
-  case RISCV::VMADD_VX:
-  case RISCV::VNMSUB_VX:
-  // 11.14. Vector Widening Integer Multiply-Add Instructions
-  case RISCV::VWMACCU_VX:
-  case RISCV::VWMACC_VX:
-  case RISCV::VWMACCSU_VX:
-  case RISCV::VWMACCUS_VX:
-  // 11.15. Vector Integer Merge Instructions
-  case RISCV::VMERGE_VXM:
-  // 11.16. Vector Integer Move Instructions
-  case RISCV::VMV_V_X:
-  // 12.1. Vector Single-Width Saturating Add and Subtract
-  case RISCV::VSADDU_VX:
-  case RISCV::VSADD_VX:
-  case RISCV::VSSUBU_VX:
-  case RISCV::VSSUB_VX:
-  // 12.2. Vector Single-Width Averaging Add and Subtract
-  case RISCV::VAADDU_VX:
-  case RISCV::VAADD_VX:
-  case RISCV::VASUBU_VX:
-  case RISCV::VASUB_VX:
-  // 12.3. Vector Single-Width Fractional Multiply with Rounding and Saturation
-  case RISCV::VSMUL_VX:
-  // 16.1. Integer Scalar Move Instructions
-  case RISCV::VMV_S_X:
-    if (Bits < (1 << Log2SEW))
-      return false;
-  }
-  return true;
-}
-
 // Checks if all users only demand the lower \p OrigBits of the original
 // instruction's result.
 // TODO: handle multiple interdependent transformations
@@ -242,10 +107,23 @@ static bool hasAllNBitUsers(const MachineInstr &OrigMI,
       unsigned OpIdx = UserOp.getOperandNo();
 
       switch (UserMI->getOpcode()) {
-      default:
-        if (vectorPseudoHasAllNBitUsers(UserOp, Bits))
-          break;
+      default: {
+        if (const RISCVVPseudosTable::PseudoInfo *PseudoInfo =
+                RISCVVPseudosTable::getPseudoInfo(UserMI->getOpcode())) {
+          const MCInstrDesc &MCID = UserMI->getDesc();
+          if (!RISCVII::hasSEWOp(MCID.TSFlags))
+            return false;
+          assert(RISCVII::hasVLOp(MCID.TSFlags));
+          const unsigned Log2SEW =
+              UserMI->getOperand(RISCVII::getSEWOpNum(MCID)).getImm();
+          if (UserOp.getOperandNo() == RISCVII::getVLOpNum(MCID))
+            return false;
+          if (RISCVII::vectorInstUsesNBitsOfScalarOp(PseudoInfo->BaseInstr,
+                                                     Bits, Log2SEW))
+            break;
+        }
         return false;
+      }
 
       case RISCV::ADDIW:
       case RISCV::ADDW:

>From 050cfe347495744df94e065535fb206409956e95 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Wed, 27 Sep 2023 21:03:31 +0100
Subject: [PATCH 2/2] Move to RISCVInstrInfo and rework to return number of
 demanded bits, restore helper functions

---
 .../RISCV/MCTargetDesc/RISCVBaseInfo.cpp      | 112 ------------------
 .../Target/RISCV/MCTargetDesc/RISCVBaseInfo.h |   5 -
 llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp   |  62 ++++++----
 llvm/lib/Target/RISCV/RISCVInstrInfo.cpp      | 112 ++++++++++++++++++
 llvm/lib/Target/RISCV/RISCVInstrInfo.h        |   6 +
 llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp     |  46 ++++---
 6 files changed, 185 insertions(+), 158 deletions(-)

diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.cpp b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.cpp
index 95cea0c61acfd5d..0a42c6faee29008 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.cpp
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.cpp
@@ -130,118 +130,6 @@ parseFeatureBits(bool IsRV64, const FeatureBitset &FeatureBits) {
 
 } // namespace RISCVFeatures
 
-bool RISCVII::vectorInstUsesNBitsOfScalarOp(uint16_t Opcode, unsigned Bits,
-                                            unsigned Log2SEW) {
-  // TODO: Handle Zvbb instructions
-  switch (Opcode) {
-  default:
-    return false;
-
-  // 11.6. Vector Single-Width Shift Instructions
-  case RISCV::VSLL_VX:
-  case RISCV::VSRL_VX:
-  case RISCV::VSRA_VX:
-  // 12.4. Vector Single-Width Scaling Shift Instructions
-  case RISCV::VSSRL_VX:
-  case RISCV::VSSRA_VX:
-    // Only the low lg2(SEW) bits of the shift-amount value are used.
-    return Log2SEW <= Bits;
-
-  // 11.7 Vector Narrowing Integer Right Shift Instructions
-  case RISCV::VNSRL_WX:
-  case RISCV::VNSRA_WX:
-  // 12.5. Vector Narrowing Fixed-Point Clip Instructions
-  case RISCV::VNCLIPU_WX:
-  case RISCV::VNCLIP_WX:
-    // Only the low lg2(2*SEW) bits of the shift-amount value are used.
-    return (Log2SEW + 1) <= Bits;
-
-  // 11.1. Vector Single-Width Integer Add and Subtract
-  case RISCV::VADD_VX:
-  case RISCV::VSUB_VX:
-  case RISCV::VRSUB_VX:
-  // 11.2. Vector Widening Integer Add/Subtract
-  case RISCV::VWADDU_VX:
-  case RISCV::VWSUBU_VX:
-  case RISCV::VWADD_VX:
-  case RISCV::VWSUB_VX:
-  case RISCV::VWADDU_WX:
-  case RISCV::VWSUBU_WX:
-  case RISCV::VWADD_WX:
-  case RISCV::VWSUB_WX:
-  // 11.4. Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
-  case RISCV::VADC_VXM:
-  case RISCV::VADC_VIM:
-  case RISCV::VMADC_VXM:
-  case RISCV::VMADC_VIM:
-  case RISCV::VMADC_VX:
-  case RISCV::VSBC_VXM:
-  case RISCV::VMSBC_VXM:
-  case RISCV::VMSBC_VX:
-  // 11.5 Vector Bitwise Logical Instructions
-  case RISCV::VAND_VX:
-  case RISCV::VOR_VX:
-  case RISCV::VXOR_VX:
-  // 11.8. Vector Integer Compare Instructions
-  case RISCV::VMSEQ_VX:
-  case RISCV::VMSNE_VX:
-  case RISCV::VMSLTU_VX:
-  case RISCV::VMSLT_VX:
-  case RISCV::VMSLEU_VX:
-  case RISCV::VMSLE_VX:
-  case RISCV::VMSGTU_VX:
-  case RISCV::VMSGT_VX:
-  // 11.9. Vector Integer Min/Max Instructions
-  case RISCV::VMINU_VX:
-  case RISCV::VMIN_VX:
-  case RISCV::VMAXU_VX:
-  case RISCV::VMAX_VX:
-  // 11.10. Vector Single-Width Integer Multiply Instructions
-  case RISCV::VMUL_VX:
-  case RISCV::VMULH_VX:
-  case RISCV::VMULHU_VX:
-  case RISCV::VMULHSU_VX:
-  // 11.11. Vector Integer Divide Instructions
-  case RISCV::VDIVU_VX:
-  case RISCV::VDIV_VX:
-  case RISCV::VREMU_VX:
-  case RISCV::VREM_VX:
-  // 11.12. Vector Widening Integer Multiply Instructions
-  case RISCV::VWMUL_VX:
-  case RISCV::VWMULU_VX:
-  case RISCV::VWMULSU_VX:
-  // 11.13. Vector Single-Width Integer Multiply-Add Instructions
-  case RISCV::VMACC_VX:
-  case RISCV::VNMSAC_VX:
-  case RISCV::VMADD_VX:
-  case RISCV::VNMSUB_VX:
-  // 11.14. Vector Widening Integer Multiply-Add Instructions
-  case RISCV::VWMACCU_VX:
-  case RISCV::VWMACC_VX:
-  case RISCV::VWMACCSU_VX:
-  case RISCV::VWMACCUS_VX:
-  // 11.15. Vector Integer Merge Instructions
-  case RISCV::VMERGE_VXM:
-  // 11.16. Vector Integer Move Instructions
-  case RISCV::VMV_V_X:
-  // 12.1. Vector Single-Width Saturating Add and Subtract
-  case RISCV::VSADDU_VX:
-  case RISCV::VSADD_VX:
-  case RISCV::VSSUBU_VX:
-  case RISCV::VSSUB_VX:
-  // 12.2. Vector Single-Width Averaging Add and Subtract
-  case RISCV::VAADDU_VX:
-  case RISCV::VAADD_VX:
-  case RISCV::VASUBU_VX:
-  case RISCV::VASUB_VX:
-  // 12.3. Vector Single-Width Fractional Multiply with Rounding and Saturation
-  case RISCV::VSMUL_VX:
-  // 16.1. Integer Scalar Move Instructions
-  case RISCV::VMV_S_X:
-    return (1 << Log2SEW) <= Bits;
-  }
-}
-
 // Encode VTYPE into the binary format used by the the VSETVLI instruction which
 // is used by our MC layer representation.
 //
diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
index 222d4e9eef674ec..20ff26a39dc3b30 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
@@ -241,11 +241,6 @@ static inline bool isFirstDefTiedToFirstUse(const MCInstrDesc &Desc) {
          Desc.getOperandConstraint(Desc.getNumDefs(), MCOI::TIED_TO) == 0;
 }
 
-// Returns true if the .vx vector instruction \p Opcode only uses the lower \p
-// Bits for a given SEW.
-bool vectorInstUsesNBitsOfScalarOp(uint16_t Opcode, unsigned Bits,
-                                   unsigned Log2SEW);
-
 // RISC-V Specific Machine Operand Flags
 enum {
   MO_None = 0,
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index b5f91c6bf70b061..65afb918f8b14c5 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -2753,6 +2753,40 @@ bool RISCVDAGToDAGISel::selectSHXADD_UWOp(SDValue N, unsigned ShAmt,
   return false;
 }
 
+static bool vectorPseudoHasAllNBitUsers(SDNode *User, unsigned UserOpNo,
+                                        unsigned Bits,
+                                        const TargetInstrInfo *TII) {
+  const RISCVVPseudosTable::PseudoInfo *PseudoInfo =
+      RISCVVPseudosTable::getPseudoInfo(User->getMachineOpcode());
+
+  if (!PseudoInfo)
+    return false;
+
+  const MCInstrDesc &MCID = TII->get(User->getMachineOpcode());
+  const uint64_t TSFlags = MCID.TSFlags;
+  if (!RISCVII::hasSEWOp(TSFlags))
+    return false;
+  assert(RISCVII::hasVLOp(TSFlags));
+
+  bool HasGlueOp = User->getGluedNode() != nullptr;
+  unsigned ChainOpIdx = User->getNumOperands() - HasGlueOp - 1;
+  bool HasChainOp = User->getOperand(ChainOpIdx).getValueType() == MVT::Other;
+  bool HasVecPolicyOp = RISCVII::hasVecPolicyOp(TSFlags);
+  unsigned VLIdx =
+      User->getNumOperands() - HasVecPolicyOp - HasChainOp - HasGlueOp - 2;
+  const unsigned Log2SEW = User->getConstantOperandVal(VLIdx + 1);
+
+  if (UserOpNo == VLIdx)
+    return false;
+
+  auto NumDemandedBits =
+      RISCV::getVectorLowDemandedScalarBits(PseudoInfo->BaseInstr, Log2SEW);
+  if (!NumDemandedBits || Bits < NumDemandedBits)
+    return false;
+
+  return true;
+}
+
 // Return true if all users of this SDNode* only consume the lower \p Bits.
 // This can be used to form W instructions for add/sub/mul/shl even when the
 // root isn't a sext_inreg. This can allow the ADDW/SUBW/MULW/SLLIW to CSE if
@@ -2783,32 +2817,10 @@ bool RISCVDAGToDAGISel::hasAllNBitUsers(SDNode *Node, unsigned Bits,
 
     // TODO: Add more opcodes?
     switch (User->getMachineOpcode()) {
-    default: {
-      if (const RISCVVPseudosTable::PseudoInfo *PseudoInfo =
-              RISCVVPseudosTable::getPseudoInfo(User->getMachineOpcode())) {
-
-        const MCInstrDesc &MCID = TII->get(User->getMachineOpcode());
-        if (!RISCVII::hasSEWOp(MCID.TSFlags))
-          return false;
-        assert(RISCVII::hasVLOp(MCID.TSFlags));
-
-        bool HasGlueOp = User->getGluedNode() != nullptr;
-        unsigned ChainOpIdx = User->getNumOperands() - HasGlueOp - 1;
-        bool HasChainOp =
-            User->getOperand(ChainOpIdx).getValueType() == MVT::Other;
-        bool HasVecPolicyOp = RISCVII::hasVecPolicyOp(MCID.TSFlags);
-        unsigned VLIdx = User->getNumOperands() - HasVecPolicyOp - HasChainOp -
-                         HasGlueOp - 2;
-        const unsigned Log2SEW = User->getConstantOperandVal(VLIdx + 1);
-
-        if (UI.getOperandNo() == VLIdx)
-          return false;
-        if (RISCVII::vectorInstUsesNBitsOfScalarOp(PseudoInfo->BaseInstr, Bits,
-                                                   Log2SEW))
-          break;
-      }
+    default:
+      if (vectorPseudoHasAllNBitUsers(User, UI.getOperandNo(), Bits, TII))
+        break;
       return false;
-    }
     case RISCV::ADDW:
     case RISCV::ADDIW:
     case RISCV::SUBW:
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 6ee5e2d4c584049..59f6fb4205daa7a 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -2847,3 +2847,115 @@ bool RISCV::hasEqualFRM(const MachineInstr &MI1, const MachineInstr &MI2) {
   MachineOperand FrmOp2 = MI2.getOperand(MI2FrmOpIdx);
   return FrmOp1.getImm() == FrmOp2.getImm();
 }
+
+std::optional<unsigned>
+RISCV::getVectorLowDemandedScalarBits(uint16_t Opcode, unsigned Log2SEW) {
+  // TODO: Handle Zvbb instructions
+  switch (Opcode) {
+  default:
+    return std::nullopt;
+
+  // 11.6. Vector Single-Width Shift Instructions
+  case RISCV::VSLL_VX:
+  case RISCV::VSRL_VX:
+  case RISCV::VSRA_VX:
+  // 12.4. Vector Single-Width Scaling Shift Instructions
+  case RISCV::VSSRL_VX:
+  case RISCV::VSSRA_VX:
+    // Only the low lg2(SEW) bits of the shift-amount value are used.
+    return Log2SEW;
+
+  // 11.7 Vector Narrowing Integer Right Shift Instructions
+  case RISCV::VNSRL_WX:
+  case RISCV::VNSRA_WX:
+  // 12.5. Vector Narrowing Fixed-Point Clip Instructions
+  case RISCV::VNCLIPU_WX:
+  case RISCV::VNCLIP_WX:
+    // Only the low lg2(2*SEW) bits of the shift-amount value are used.
+    return Log2SEW + 1;
+
+  // 11.1. Vector Single-Width Integer Add and Subtract
+  case RISCV::VADD_VX:
+  case RISCV::VSUB_VX:
+  case RISCV::VRSUB_VX:
+  // 11.2. Vector Widening Integer Add/Subtract
+  case RISCV::VWADDU_VX:
+  case RISCV::VWSUBU_VX:
+  case RISCV::VWADD_VX:
+  case RISCV::VWSUB_VX:
+  case RISCV::VWADDU_WX:
+  case RISCV::VWSUBU_WX:
+  case RISCV::VWADD_WX:
+  case RISCV::VWSUB_WX:
+  // 11.4. Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
+  case RISCV::VADC_VXM:
+  case RISCV::VADC_VIM:
+  case RISCV::VMADC_VXM:
+  case RISCV::VMADC_VIM:
+  case RISCV::VMADC_VX:
+  case RISCV::VSBC_VXM:
+  case RISCV::VMSBC_VXM:
+  case RISCV::VMSBC_VX:
+  // 11.5 Vector Bitwise Logical Instructions
+  case RISCV::VAND_VX:
+  case RISCV::VOR_VX:
+  case RISCV::VXOR_VX:
+  // 11.8. Vector Integer Compare Instructions
+  case RISCV::VMSEQ_VX:
+  case RISCV::VMSNE_VX:
+  case RISCV::VMSLTU_VX:
+  case RISCV::VMSLT_VX:
+  case RISCV::VMSLEU_VX:
+  case RISCV::VMSLE_VX:
+  case RISCV::VMSGTU_VX:
+  case RISCV::VMSGT_VX:
+  // 11.9. Vector Integer Min/Max Instructions
+  case RISCV::VMINU_VX:
+  case RISCV::VMIN_VX:
+  case RISCV::VMAXU_VX:
+  case RISCV::VMAX_VX:
+  // 11.10. Vector Single-Width Integer Multiply Instructions
+  case RISCV::VMUL_VX:
+  case RISCV::VMULH_VX:
+  case RISCV::VMULHU_VX:
+  case RISCV::VMULHSU_VX:
+  // 11.11. Vector Integer Divide Instructions
+  case RISCV::VDIVU_VX:
+  case RISCV::VDIV_VX:
+  case RISCV::VREMU_VX:
+  case RISCV::VREM_VX:
+  // 11.12. Vector Widening Integer Multiply Instructions
+  case RISCV::VWMUL_VX:
+  case RISCV::VWMULU_VX:
+  case RISCV::VWMULSU_VX:
+  // 11.13. Vector Single-Width Integer Multiply-Add Instructions
+  case RISCV::VMACC_VX:
+  case RISCV::VNMSAC_VX:
+  case RISCV::VMADD_VX:
+  case RISCV::VNMSUB_VX:
+  // 11.14. Vector Widening Integer Multiply-Add Instructions
+  case RISCV::VWMACCU_VX:
+  case RISCV::VWMACC_VX:
+  case RISCV::VWMACCSU_VX:
+  case RISCV::VWMACCUS_VX:
+  // 11.15. Vector Integer Merge Instructions
+  case RISCV::VMERGE_VXM:
+  // 11.16. Vector Integer Move Instructions
+  case RISCV::VMV_V_X:
+  // 12.1. Vector Single-Width Saturating Add and Subtract
+  case RISCV::VSADDU_VX:
+  case RISCV::VSADD_VX:
+  case RISCV::VSSUBU_VX:
+  case RISCV::VSSUB_VX:
+  // 12.2. Vector Single-Width Averaging Add and Subtract
+  case RISCV::VAADDU_VX:
+  case RISCV::VAADD_VX:
+  case RISCV::VASUBU_VX:
+  case RISCV::VASUB_VX:
+  // 12.3. Vector Single-Width Fractional Multiply with Rounding and Saturation
+  case RISCV::VSMUL_VX:
+  // 16.1. Integer Scalar Move Instructions
+  case RISCV::VMV_S_X:
+    return 1 << Log2SEW;
+  }
+}
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index 99c907a98121ae3..d56d3c0b303bf91 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -265,6 +265,12 @@ int16_t getNamedOperandIdx(uint16_t Opcode, uint16_t NamedIndex);
 // one of the instructions does not have rounding mode, false will be returned.
 bool hasEqualFRM(const MachineInstr &MI1, const MachineInstr &MI2);
 
+// If \p Opcode is a .vx vector instruction, returns the number of low bits of
+// the scalar .x operand that are used for a given \p Log2SEW. Otherwise
+// returns std::nullopt.
+std::optional<unsigned> getVectorLowDemandedScalarBits(uint16_t Opcode,
+                                                       unsigned Log2SEW);
+
 // Special immediate for AVL operand of V pseudo instructions to indicate VLMax.
 static constexpr int64_t VLMaxSentinel = -1LL;
 
diff --git a/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp b/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp
index 56aa5589bc19a29..efc371498e75f11 100644
--- a/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp
+++ b/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp
@@ -78,6 +78,33 @@ FunctionPass *llvm::createRISCVOptWInstrsPass() {
   return new RISCVOptWInstrs();
 }
 
+static bool vectorPseudoHasAllNBitUsers(const MachineOperand &UserOp,
+                                        unsigned Bits) {
+  const MachineInstr &MI = *UserOp.getParent();
+  const RISCVVPseudosTable::PseudoInfo *PseudoInfo =
+      RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
+
+  if (!PseudoInfo)
+    return false;
+
+  const MCInstrDesc &MCID = MI.getDesc();
+  const uint64_t TSFlags = MCID.TSFlags;
+  if (!RISCVII::hasSEWOp(TSFlags))
+    return false;
+  assert(RISCVII::hasVLOp(TSFlags));
+  const unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MCID)).getImm();
+
+  if (UserOp.getOperandNo() == RISCVII::getVLOpNum(MCID))
+    return false;
+
+  auto NumDemandedBits =
+      RISCV::getVectorLowDemandedScalarBits(PseudoInfo->BaseInstr, Log2SEW);
+  if (!NumDemandedBits || Bits < NumDemandedBits)
+    return false;
+
+  return true;
+}
+
 // Checks if all users only demand the lower \p OrigBits of the original
 // instruction's result.
 // TODO: handle multiple interdependent transformations
@@ -107,23 +134,10 @@ static bool hasAllNBitUsers(const MachineInstr &OrigMI,
       unsigned OpIdx = UserOp.getOperandNo();
 
       switch (UserMI->getOpcode()) {
-      default: {
-        if (const RISCVVPseudosTable::PseudoInfo *PseudoInfo =
-                RISCVVPseudosTable::getPseudoInfo(UserMI->getOpcode())) {
-          const MCInstrDesc &MCID = UserMI->getDesc();
-          if (!RISCVII::hasSEWOp(MCID.TSFlags))
-            return false;
-          assert(RISCVII::hasVLOp(MCID.TSFlags));
-          const unsigned Log2SEW =
-              UserMI->getOperand(RISCVII::getSEWOpNum(MCID)).getImm();
-          if (UserOp.getOperandNo() == RISCVII::getVLOpNum(MCID))
-            return false;
-          if (RISCVII::vectorInstUsesNBitsOfScalarOp(PseudoInfo->BaseInstr,
-                                                     Bits, Log2SEW))
-            break;
-        }
+      default:
+        if (vectorPseudoHasAllNBitUsers(UserOp, Bits))
+          break;
         return false;
-      }
 
       case RISCV::ADDIW:
       case RISCV::ADDW:



More information about the llvm-commits mailing list