[llvm] [NVPTX] support immediate values in st.param instructions (PR #91523)
Artem Belevich via llvm-commits
llvm-commits at lists.llvm.org
Wed May 8 14:59:49 PDT 2024
================
@@ -2182,6 +2182,84 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
return true;
}
+// Helpers for constructing opcode (ex: NVPTX::StoreParamV4F32_iiri)
+//
+// The opcode suffix carries one letter per stored element: 'i' when that
+// element is an immediate and 'r' when it is a register. getOpcV2H/getOpcV4H
+// paste the suffix tokens (kindN is the literal token i or r). The
+// getOpcV2H<k>/getOpcV4H<k> helpers turn the runtime bool isImm<k> into the
+// token for element k via a ternary. This lets an IsImm array (indexable with
+// bools, at least 2 entries for V2 and 4 for V4) select among the 2^N
+// opcode variants.
+//
+// Every expansion and argument is fully parenthesized so the macros stay
+// correct when used as a subexpression, not only as `return MACRO(...)`.
+//
+// Note: both arms of each ternary are token-pasted at compile time.
+// getOpcodeForVectorStParam therefore requires StoreParamV2<ty>_* and
+// StoreParamV4<ty>_* opcodes to both exist. A type with only a V2 form must
+// call getOpcodeForVectorStParamV2 directly.
+#define getOpcV2H(ty, kind0, kind1) NVPTX::StoreParamV2##ty##_##kind0##kind1
+
+#define getOpcV2H1(ty, kind0, isImm1)                                          \
+  ((isImm1) ? getOpcV2H(ty, kind0, i) : getOpcV2H(ty, kind0, r))
+
+#define getOpcodeForVectorStParamV2(ty, isimm)                                 \
+  ((isimm)[0] ? getOpcV2H1(ty, i, (isimm)[1])                                  \
+              : getOpcV2H1(ty, r, (isimm)[1]))
+
+#define getOpcV4H(ty, kind0, kind1, kind2, kind3)                              \
+  NVPTX::StoreParamV4##ty##_##kind0##kind1##kind2##kind3
+
+#define getOpcV4H3(ty, kind0, kind1, kind2, isImm3)                            \
+  ((isImm3) ? getOpcV4H(ty, kind0, kind1, kind2, i)                            \
+            : getOpcV4H(ty, kind0, kind1, kind2, r))
+
+#define getOpcV4H2(ty, kind0, kind1, isImm2, isImm3)                           \
+  ((isImm2) ? getOpcV4H3(ty, kind0, kind1, i, isImm3)                          \
+            : getOpcV4H3(ty, kind0, kind1, r, isImm3))
+
+#define getOpcV4H1(ty, kind0, isImm1, isImm2, isImm3)                          \
+  ((isImm1) ? getOpcV4H2(ty, kind0, i, isImm2, isImm3)                         \
+            : getOpcV4H2(ty, kind0, r, isImm2, isImm3))
+
+#define getOpcodeForVectorStParamV4(ty, isimm)                                 \
+  ((isimm)[0] ? getOpcV4H1(ty, i, (isimm)[1], (isimm)[2], (isimm)[3])          \
+              : getOpcV4H1(ty, r, (isimm)[1], (isimm)[2], (isimm)[3]))
+
+// Dispatches on the element count: n == 2 selects a V2 opcode, anything
+// else a V4 opcode.
+#define getOpcodeForVectorStParam(n, ty, isimm)                                \
+  (((n) == 2) ? getOpcodeForVectorStParamV2(ty, isimm)                         \
+              : getOpcodeForVectorStParamV4(ty, isimm))
+
+static std::optional<unsigned>
+pickOpcodeForVectorStParam(SmallVector<SDValue, 8> &Ops, unsigned NumElts,
+ MVT::SimpleValueType MemTy, SelectionDAG *CurDAG,
+ SDLoc DL) {
+ // Determine which inputs are registers and immediates make new operators
+ // with constant values
+ SmallVector<bool, 4> IsImm(NumElts, false);
+ for (unsigned i = 0; i < NumElts; i++) {
+ IsImm[i] = (isa<ConstantSDNode>(Ops[i]) || isa<ConstantFPSDNode>(Ops[i]));
+ if (IsImm[i]) {
+ SDValue Imm = Ops[i];
+ if (MemTy == MVT::f32 || MemTy == MVT::f64) {
+ const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
+ const ConstantFP *CF = ConstImm->getConstantFPValue();
+ Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
+ } else {
+ const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
+ const ConstantInt *CI = ConstImm->getConstantIntValue();
+ Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
+ }
+ Ops[i] = Imm;
+ }
+ }
+
+ // Get opcode for MemTy, size, and register/immediate operand ordering
+ switch (MemTy) {
+ case MVT::i8:
+ return getOpcodeForVectorStParam(NumElts, I8, IsImm);
+ case MVT::i16:
+ return getOpcodeForVectorStParam(NumElts, I16, IsImm);
+ case MVT::i32:
+ return getOpcodeForVectorStParam(NumElts, I32, IsImm);
+ case MVT::i64:
+ if (NumElts == 4)
+ return std::nullopt;
+ return getOpcodeForVectorStParamV2(I64, IsImm);
+ case MVT::f32:
+ return getOpcodeForVectorStParam(NumElts, F32, IsImm);
+ case MVT::f64:
+ if (NumElts == 4)
+ return std::nullopt;
+ return getOpcodeForVectorStParamV2(F64, IsImm);
+ case MVT::f16:
+ case MVT::v2f16:
----------------
Artem-B wrote:
Is seeing f16 types here an error, or business as usual? If that's an error, then we should have llvm_unreachable or an assert here. Considering that those are loaded/stored as i16, I suspect it should be an error.
https://github.com/llvm/llvm-project/pull/91523
More information about the llvm-commits mailing list