[llvm] [NVPTX] Add PRMT constant folding and cleanup usage of PRMT node (PR #148906)
Alex MacLean via llvm-commits
llvm-commits at lists.llvm.org
Thu Jul 17 10:32:49 PDT 2025
================
@@ -5797,47 +5844,116 @@ static SDValue combineADDRSPACECAST(SDNode *N,
return SDValue();
}
+/// Expand a PRMT selector operand into an explicit four-nibble selector.
+///
+/// With mode NONE the selector operand is already in explicit form and is
+/// returned unchanged. For every other mode only the low two bits of
+/// \p Selector are consulted; they are expanded into the equivalent explicit
+/// selector, where nibble I (bits [4*I+3 : 4*I]) names the source byte for
+/// result byte I.
+static APInt getPRMTSelector(APInt Selector, unsigned Mode) {
+  if (Mode == NVPTX::PTXPrmtMode::NONE)
+    return Selector;
+
+  // Pack four source-byte indices (result byte 0 first) into a 32-bit value.
+  const auto Pack = [](unsigned B0, unsigned B1, unsigned B2, unsigned B3) {
+    return APInt(32, (B3 << 12) | (B2 << 8) | (B1 << 4) | B0);
+  };
+
+  const unsigned C = Selector.trunc(2).getZExtValue();
+  switch (Mode) {
+  case NVPTX::PTXPrmtMode::F4E:
+    // Forward 4 extract: source bytes C, C+1, C+2, C+3.
+    return Pack(C, C + 1, C + 2, C + 3);
+  case NVPTX::PTXPrmtMode::B4E:
+    // Backward 4 extract: C, C-1, C-2, C-3, wrapping modulo 8.
+    return Pack(C, (C + 7) & 7, (C + 6) & 7, (C + 5) & 7);
+  case NVPTX::PTXPrmtMode::RC8:
+    // Replicate source byte C into every result byte.
+    return Pack(C, C, C, C);
+  case NVPTX::PTXPrmtMode::ECL:
+    // Edge clamp left: result byte I takes source byte max(C, I).
+    return Pack(C, std::max(C, 1U), std::max(C, 2U), 3U);
+  case NVPTX::PTXPrmtMode::ECR:
+    // Edge clamp right: result byte I takes source byte min(C, I).
+    return Pack(0U, std::min(C, 1U), std::min(C, 2U), C);
+  case NVPTX::PTXPrmtMode::RC16: {
+    // Replicate a 16-bit pair: source bytes {0,1} when C is even, {2,3} when
+    // C is odd.
+    const unsigned Lo = (C & 1) ? 2U : 0U;
+    return Pack(Lo, Lo + 1, Lo, Lo + 1);
+  }
+  default:
+    llvm_unreachable("Invalid PRMT mode");
+  }
+}
+
+/// Constant-evaluate a PRMT (byte permute) of \p A and \p B.
+///
+/// The eight source bytes are {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}},
+/// with \p A supplying bytes 0-3. After expanding the selector for \p Mode,
+/// nibble I chooses the source for result byte I:
+/// - its low three bits give the source byte index;
+/// - if its high bit is set, the result byte is filled with the sign (top)
+///   bit of that source byte.
+static APInt computePRMT(APInt A, APInt B, APInt Selector, unsigned Mode) {
+  // B occupies the high 32 bits, A the low 32 bits.
+  const APInt Source = B.concat(A);
+  const APInt Nibbles = getPRMTSelector(Selector, Mode);
+
+  APInt Out(32, 0);
+  for (unsigned ByteNo = 0; ByteNo != 4; ++ByteNo) {
+    const APInt Nibble = Nibbles.extractBits(4, ByteNo * 4);
+    const unsigned SrcByte = Nibble.getLoBits(3).getZExtValue();
+    APInt Picked = Source.extractBits(8, SrcByte * 8);
+    // An arithmetic shift by the full 8-bit width broadcasts the sign bit.
+    if (Nibble.getHiBits(1).getZExtValue() != 0)
+      Picked = Picked.ashr(8);
+    Out.insertBits(Picked, ByteNo * 8);
+  }
+  return Out;
+}
+
+/// DAG combine for NVPTXISD::PRMT.
+///
+/// Currently this only constant-folds a PRMT whose operands 0-2 (A, B and the
+/// selector) are all ConstantSDNodes. Operand 3 is read as the PRMT mode. No
+/// combine is attempted when \p OptLevel is None.
+static SDValue combinePRMT(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+                           CodeGenOptLevel OptLevel) {
+  if (OptLevel == CodeGenOptLevel::None)
+    return SDValue();
+
+  const auto IsConstOperand = [N](unsigned OpNo) {
+    return isa<ConstantSDNode>(N->getOperand(OpNo));
+  };
+  if (!IsConstOperand(0) || !IsConstOperand(1) || !IsConstOperand(2))
+    return SDValue();
+
+  // Constant fold PRMT
+  const APInt Folded = computePRMT(
+      N->getConstantOperandAPInt(0), N->getConstantOperandAPInt(1),
+      N->getConstantOperandAPInt(2), N->getConstantOperandVal(3));
+  return DCI.DAG.getConstant(Folded, SDLoc(N), N->getValueType(0));
+}
+
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
switch (N->getOpcode()) {
- default: break;
- case ISD::ADD:
- return PerformADDCombine(N, DCI, OptLevel);
- case ISD::FADD:
- return PerformFADDCombine(N, DCI, OptLevel);
- case ISD::MUL:
- return PerformMULCombine(N, DCI, OptLevel);
- case ISD::SHL:
- return PerformSHLCombine(N, DCI, OptLevel);
- case ISD::AND:
- return PerformANDCombine(N, DCI);
- case ISD::UREM:
- case ISD::SREM:
- return PerformREMCombine(N, DCI, OptLevel);
- case ISD::SETCC:
- return PerformSETCCCombine(N, DCI, STI.getSmVersion());
- case ISD::LOAD:
- case NVPTXISD::LoadParamV2:
- case NVPTXISD::LoadV2:
- case NVPTXISD::LoadV4:
- return combineUnpackingMovIntoLoad(N, DCI);
- case NVPTXISD::StoreParam:
- case NVPTXISD::StoreParamV2:
- case NVPTXISD::StoreParamV4:
- return PerformStoreParamCombine(N, DCI);
- case ISD::STORE:
- case NVPTXISD::StoreV2:
- case NVPTXISD::StoreV4:
- return PerformStoreCombine(N, DCI);
- case ISD::EXTRACT_VECTOR_ELT:
- return PerformEXTRACTCombine(N, DCI);
- case ISD::VSELECT:
- return PerformVSELECTCombine(N, DCI);
- case ISD::BUILD_VECTOR:
- return PerformBUILD_VECTORCombine(N, DCI);
- case ISD::ADDRSPACECAST:
- return combineADDRSPACECAST(N, DCI);
+ default:
+ break;
+ case ISD::ADD:
+ return PerformADDCombine(N, DCI, OptLevel);
+ case ISD::FADD:
+ return PerformFADDCombine(N, DCI, OptLevel);
+ case ISD::MUL:
+ return PerformMULCombine(N, DCI, OptLevel);
+ case ISD::SHL:
+ return PerformSHLCombine(N, DCI, OptLevel);
+ case ISD::AND:
+ return PerformANDCombine(N, DCI);
+ case ISD::UREM:
+ case ISD::SREM:
+ return PerformREMCombine(N, DCI, OptLevel);
+ case ISD::SETCC:
+ return PerformSETCCCombine(N, DCI, STI.getSmVersion());
+ case ISD::LOAD:
+ case NVPTXISD::LoadParamV2:
+ case NVPTXISD::LoadV2:
+ case NVPTXISD::LoadV4:
+ return combineUnpackingMovIntoLoad(N, DCI);
+ case NVPTXISD::StoreParam:
+ case NVPTXISD::StoreParamV2:
+ case NVPTXISD::StoreParamV4:
+ return PerformStoreParamCombine(N, DCI);
+ case ISD::STORE:
+ case NVPTXISD::StoreV2:
+ case NVPTXISD::StoreV4:
+ return PerformStoreCombine(N, DCI);
+ case ISD::EXTRACT_VECTOR_ELT:
+ return PerformEXTRACTCombine(N, DCI);
+ case ISD::VSELECT:
+ return PerformVSELECTCombine(N, DCI);
+ case ISD::BUILD_VECTOR:
+ return PerformBUILD_VECTORCombine(N, DCI);
+ case ISD::ADDRSPACECAST:
+ return combineADDRSPACECAST(N, DCI);
+ case NVPTXISD::PRMT:
+ return combinePRMT(N, DCI, OptLevel);
----------------
AlexMaclean wrote:
Done
https://github.com/llvm/llvm-project/pull/148906
More information about the llvm-commits mailing list