[llvm] [NVPTX] Use PRMT more widely, and improve folding around this instruction (PR #148261)

Kevin McAfee via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 11 11:15:47 PDT 2025


================
@@ -6334,3 +6309,45 @@ MCSection *NVPTXTargetObjectFile::SelectSectionForGlobal(
     const GlobalObject *GO, SectionKind Kind, const TargetMachine &TM) const {
   return getDataSection();
 }
+
+static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
+                                    const SelectionDAG &DAG, unsigned Depth) { // Writes the known bits of PRMT node Op into Known.
+  SDValue A = Op.getOperand(0); // Operand 0: first source value.
+  SDValue B = Op.getOperand(1); // Operand 1: second source value.
+  ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(Op.getOperand(2)); // Null unless operand 2 is a constant.
+  unsigned Mode = Op.getConstantOperandVal(3); // Operand 3: permute mode.
+
+  if (Mode != NVPTX::PTXPrmtMode::NONE || !Selector) // Only mode NONE with a constant selector is analysed.
+    return; // Known is left exactly as the caller passed it.
+
+  KnownBits AKnown = DAG.computeKnownBits(A, Depth); // NOTE(review): Depth is forwarded unchanged; confirm Depth + 1 is not intended.
+  KnownBits BKnown = DAG.computeKnownBits(B, Depth);
+
+  // Source bytes the selector indexes into: {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
+  KnownBits BitField = BKnown.concat(AKnown);
+
+  APInt SelectorVal = Selector->getAPIntValue();
+  for (unsigned I : llvm::seq(std::min(4U, Known.getBitWidth() / 8))) { // One iteration per result byte, at most four.
+    APInt Sel = SelectorVal.extractBits(4, I * 4); // Selector nibble I controls result byte I.
+    unsigned Idx = Sel.getLoBits(3).getZExtValue(); // Low three bits: byte index into BitField.
+    unsigned Sign = Sel.getHiBits(1).getZExtValue(); // Top bit of the nibble.
+    KnownBits Byte = BitField.extractBits(8, Idx * 8);
+    if (Sign)
+      Byte = KnownBits::ashr(Byte, 8); // NOTE(review): presumably meant to replicate the byte's sign bit; confirm the literal 8 becomes the intended shift-amount operand.
+    Known.insertBits(Byte, I * 8);
+  }
+}
+
+void NVPTXTargetLowering::computeKnownBitsForTargetNode(
----------------
kalxr wrote:

Is this exercised by the tests?

https://github.com/llvm/llvm-project/pull/148261


More information about the llvm-commits mailing list