[llvm] [NVPTX] Add PRMT constant folding and cleanup usage of PRMT node (PR #148906)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 17 10:32:49 PDT 2025


================
@@ -5797,47 +5844,116 @@ static SDValue combineADDRSPACECAST(SDNode *N,
   return SDValue();
 }
 
+/// Expand a PRMT selector operand into the generic four-nibble form.
+///
+/// For NVPTX::PTXPrmtMode::NONE the selector is returned unchanged. For every
+/// other mode only the two low bits of \p Selector are read, and an equivalent
+/// selector is built in which nibble I (bits [4*I+3 : 4*I]) holds the byte
+/// index chosen for result byte I.
+///
+/// \param Selector the 32-bit selector operand.
+/// \param Mode     one of the NVPTX::PTXPrmtMode values.
+/// \returns the 32-bit generic-form selector.
+static APInt getPRMTSelector(const APInt &Selector, unsigned Mode) {
+  // The NONE path hands Selector back as-is while the other paths construct a
+  // 32-bit value; requiring 32 bits keeps the result width uniform.
+  assert(Selector.getBitWidth() == 32 && "PRMT must have i32 operands");
+
+  if (Mode == NVPTX::PTXPrmtMode::NONE)
+    return Selector;
+
+  // All special modes are controlled by the two low bits only.
+  const unsigned V = Selector.trunc(2).getZExtValue();
+
+  // Pack four byte indices into the low 16 bits, one nibble per result byte.
+  const auto GetSelector = [](unsigned S0, unsigned S1, unsigned S2,
+                              unsigned S3) {
+    return APInt(32, S0 | (S1 << 4) | (S2 << 8) | (S3 << 12));
+  };
+
+  switch (Mode) {
+  case NVPTX::PTXPrmtMode::F4E:
+    // V is at most 3, so V + 3 never exceeds 6 and needs no masking.
+    return GetSelector(V, V + 1, V + 2, V + 3);
+  case NVPTX::PTXPrmtMode::B4E:
+    // Unsigned wrap-around followed by "& 7" is deliberate: it yields the
+    // index modulo 8 when V - k would be negative.
+    return GetSelector(V, (V - 1) & 7, (V - 2) & 7, (V - 3) & 7);
+  case NVPTX::PTXPrmtMode::RC8:
+    return GetSelector(V, V, V, V);
+  case NVPTX::PTXPrmtMode::ECL:
+    return GetSelector(V, std::max(V, 1U), std::max(V, 2U), 3U);
+  case NVPTX::PTXPrmtMode::ECR:
+    return GetSelector(0, std::min(V, 1U), std::min(V, 2U), V);
+  case NVPTX::PTXPrmtMode::RC16: {
+    // Only bit 0 of V matters here; it picks the byte pair {0,1} or {2,3}.
+    const unsigned V1 = (V & 1) << 1;
+    return GetSelector(V1, V1 + 1, V1, V1 + 1);
+  }
+  default:
+    llvm_unreachable("Invalid PRMT mode");
+  }
+}
+
+/// Evaluate a PRMT (byte permute) on constant operands.
+///
+/// \p B and \p A are concatenated into a 64-bit field with \p A in the low
+/// half. The selector is first expanded by getPRMTSelector(); then, for each
+/// of the four result bytes, the matching selector nibble is decoded: its low
+/// three bits pick a byte of the field, and its top bit requests that the
+/// picked byte be replaced by copies of its own sign bit.
+///
+/// \param A        low 32-bit source operand.
+/// \param B        high 32-bit source operand.
+/// \param Selector 32-bit selector operand.
+/// \param Mode     one of the NVPTX::PTXPrmtMode values.
+/// \returns the 32-bit permuted value.
+static APInt computePRMT(const APInt &A, const APInt &B, const APInt &Selector,
+                         unsigned Mode) {
+  // Byte indices 0..7 are only meaningful if the concatenation is exactly
+  // 64 bits wide.
+  assert(A.getBitWidth() == 32 && B.getBitWidth() == 32 &&
+         Selector.getBitWidth() == 32 && "PRMT must have i32 operands");
+
+  // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
+  const APInt BitField = B.concat(A);
+  const APInt SelectorVal = getPRMTSelector(Selector, Mode);
+  APInt Result(32, 0);
+  for (unsigned I : llvm::seq(4U)) {
+    const APInt Sel = SelectorVal.extractBits(4, I * 4);
+    const unsigned Idx = Sel.getLoBits(3).getZExtValue();
+    const bool ReplicateSign = Sel[3];
+    APInt Byte = BitField.extractBits(8, Idx * 8);
+    // An arithmetic shift of an 8-bit value by its full width leaves every
+    // bit equal to the original sign bit (0x00 or 0xFF).
+    if (ReplicateSign)
+      Byte = Byte.ashr(8);
+    Result.insertBits(Byte, I * 8);
+  }
+  return Result;
+}
+
+/// DAG combine for PRMT nodes.
+///
+/// Does nothing when optimization is disabled. Otherwise, when operands 0, 1
+/// and 2 are all constants, the node is replaced by the constant produced by
+/// computePRMT(), with operand 3 supplying the mode.
+///
+/// \returns the folded constant, or an empty SDValue when no fold applies.
+static SDValue combinePRMT(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+                           CodeGenOptLevel OptLevel) {
+  if (OptLevel == CodeGenOptLevel::None)
+    return SDValue();
+
+  const auto IsConstantOperand = [N](unsigned OpIdx) {
+    return isa<ConstantSDNode>(N->getOperand(OpIdx));
+  };
+
+  // Constant folding needs both sources and the selector to be constants.
+  if (!IsConstantOperand(0) || !IsConstantOperand(1) || !IsConstantOperand(2))
+    return SDValue();
+
+  const APInt Folded = computePRMT(
+      N->getConstantOperandAPInt(0), N->getConstantOperandAPInt(1),
+      N->getConstantOperandAPInt(2), N->getConstantOperandVal(3));
+  return DCI.DAG.getConstant(Folded, SDLoc(N), N->getValueType(0));
+}
+
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
   switch (N->getOpcode()) {
-    default: break;
-    case ISD::ADD:
-      return PerformADDCombine(N, DCI, OptLevel);
-    case ISD::FADD:
-      return PerformFADDCombine(N, DCI, OptLevel);
-    case ISD::MUL:
-      return PerformMULCombine(N, DCI, OptLevel);
-    case ISD::SHL:
-      return PerformSHLCombine(N, DCI, OptLevel);
-    case ISD::AND:
-      return PerformANDCombine(N, DCI);
-    case ISD::UREM:
-    case ISD::SREM:
-      return PerformREMCombine(N, DCI, OptLevel);
-    case ISD::SETCC:
-      return PerformSETCCCombine(N, DCI, STI.getSmVersion());
-    case ISD::LOAD:
-    case NVPTXISD::LoadParamV2:
-    case NVPTXISD::LoadV2:
-    case NVPTXISD::LoadV4:
-      return combineUnpackingMovIntoLoad(N, DCI);
-    case NVPTXISD::StoreParam:
-    case NVPTXISD::StoreParamV2:
-    case NVPTXISD::StoreParamV4:
-      return PerformStoreParamCombine(N, DCI);
-    case ISD::STORE:
-    case NVPTXISD::StoreV2:
-    case NVPTXISD::StoreV4:
-      return PerformStoreCombine(N, DCI);
-    case ISD::EXTRACT_VECTOR_ELT:
-      return PerformEXTRACTCombine(N, DCI);
-    case ISD::VSELECT:
-      return PerformVSELECTCombine(N, DCI);
-    case ISD::BUILD_VECTOR:
-      return PerformBUILD_VECTORCombine(N, DCI);
-    case ISD::ADDRSPACECAST:
-      return combineADDRSPACECAST(N, DCI);
+  default:
+    break;
+  case ISD::ADD:
+    return PerformADDCombine(N, DCI, OptLevel);
+  case ISD::FADD:
+    return PerformFADDCombine(N, DCI, OptLevel);
+  case ISD::MUL:
+    return PerformMULCombine(N, DCI, OptLevel);
+  case ISD::SHL:
+    return PerformSHLCombine(N, DCI, OptLevel);
+  case ISD::AND:
+    return PerformANDCombine(N, DCI);
+  case ISD::UREM:
+  case ISD::SREM:
+    return PerformREMCombine(N, DCI, OptLevel);
+  case ISD::SETCC:
+    return PerformSETCCCombine(N, DCI, STI.getSmVersion());
+  case ISD::LOAD:
+  case NVPTXISD::LoadParamV2:
+  case NVPTXISD::LoadV2:
+  case NVPTXISD::LoadV4:
+    return combineUnpackingMovIntoLoad(N, DCI);
+  case NVPTXISD::StoreParam:
+  case NVPTXISD::StoreParamV2:
+  case NVPTXISD::StoreParamV4:
+    return PerformStoreParamCombine(N, DCI);
+  case ISD::STORE:
+  case NVPTXISD::StoreV2:
+  case NVPTXISD::StoreV4:
+    return PerformStoreCombine(N, DCI);
+  case ISD::EXTRACT_VECTOR_ELT:
+    return PerformEXTRACTCombine(N, DCI);
+  case ISD::VSELECT:
+    return PerformVSELECTCombine(N, DCI);
+  case ISD::BUILD_VECTOR:
+    return PerformBUILD_VECTORCombine(N, DCI);
+  case ISD::ADDRSPACECAST:
+    return combineADDRSPACECAST(N, DCI);
+  case NVPTXISD::PRMT:
+    return combinePRMT(N, DCI, OptLevel);
----------------
AlexMaclean wrote:

Done

https://github.com/llvm/llvm-project/pull/148906


More information about the llvm-commits mailing list