[llvm] d7f77b2 - [NVPTX] Cleanup various vestigial elements and fold together more table-gen (NFC) (#151447)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Jul 31 13:50:58 PDT 2025
Author: Alex MacLean
Date: 2025-07-31T13:50:56-07:00
New Revision: d7f77b2e82aefa575184ecafe7f67278560bd18d
URL: https://github.com/llvm/llvm-project/commit/d7f77b2e82aefa575184ecafe7f67278560bd18d
DIFF: https://github.com/llvm/llvm-project/commit/d7f77b2e82aefa575184ecafe7f67278560bd18d.diff
LOG: [NVPTX] Cleanup various vestigial elements and fold together more table-gen (NFC) (#151447)
Added:
Modified:
llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
llvm/lib/Target/NVPTX/NVPTXInstrFormats.td
llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
Removed:
################################################################################
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 8eec91562ecfe..ee1ca4538554b 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -391,16 +391,6 @@ void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
}
}
-void NVPTXInstPrinter::printOffseti32imm(const MCInst *MI, int OpNum,
- raw_ostream &O) {
- auto &Op = MI->getOperand(OpNum);
- assert(Op.isImm() && "Invalid operand");
- if (Op.getImm() != 0) {
- O << "+";
- printOperand(MI, OpNum, O);
- }
-}
-
void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum,
raw_ostream &O) {
int64_t Imm = MI->getOperand(OpNum).getImm();
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index c3ff3469150e4..92155b01464e8 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -46,7 +46,6 @@ class NVPTXInstPrinter : public MCInstPrinter {
StringRef Modifier = {});
void printMemOperand(const MCInst *MI, int OpNum, raw_ostream &O,
StringRef Modifier = {});
- void printOffseti32imm(const MCInst *MI, int OpNum, raw_ostream &O);
void printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O);
void printProtoIdent(const MCInst *MI, int OpNum, raw_ostream &O);
void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O);
diff --git a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
index cd404819cb837..a3496090def3c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
@@ -56,15 +56,12 @@ static bool traverseMoveUse(MachineInstr &U, const MachineRegisterInfo &MRI,
case NVPTX::LD_i16:
case NVPTX::LD_i32:
case NVPTX::LD_i64:
- case NVPTX::LD_i8:
case NVPTX::LDV_i16_v2:
case NVPTX::LDV_i16_v4:
case NVPTX::LDV_i32_v2:
case NVPTX::LDV_i32_v4:
case NVPTX::LDV_i64_v2:
- case NVPTX::LDV_i64_v4:
- case NVPTX::LDV_i8_v2:
- case NVPTX::LDV_i8_v4: {
+ case NVPTX::LDV_i64_v4: {
LoadInsts.push_back(&U);
return true;
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 95abcded46485..6068035b2ee47 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -1003,14 +1003,10 @@ void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
// Helper function template to reduce amount of boilerplate code for
// opcode selection.
static std::optional<unsigned>
-pickOpcodeForVT(MVT::SimpleValueType VT, std::optional<unsigned> Opcode_i8,
- std::optional<unsigned> Opcode_i16,
+pickOpcodeForVT(MVT::SimpleValueType VT, std::optional<unsigned> Opcode_i16,
std::optional<unsigned> Opcode_i32,
std::optional<unsigned> Opcode_i64) {
switch (VT) {
- case MVT::i1:
- case MVT::i8:
- return Opcode_i8;
case MVT::f16:
case MVT::i16:
case MVT::bf16:
@@ -1078,8 +1074,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
Chain};
const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy;
- const std::optional<unsigned> Opcode = pickOpcodeForVT(
- TargetVT, NVPTX::LD_i8, NVPTX::LD_i16, NVPTX::LD_i32, NVPTX::LD_i64);
+ const std::optional<unsigned> Opcode =
+ pickOpcodeForVT(TargetVT, NVPTX::LD_i16, NVPTX::LD_i32, NVPTX::LD_i64);
if (!Opcode)
return false;
@@ -1164,17 +1160,15 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
default:
llvm_unreachable("Unexpected opcode");
case NVPTXISD::LoadV2:
- Opcode =
- pickOpcodeForVT(EltVT.SimpleTy, NVPTX::LDV_i8_v2, NVPTX::LDV_i16_v2,
- NVPTX::LDV_i32_v2, NVPTX::LDV_i64_v2);
+ Opcode = pickOpcodeForVT(EltVT.SimpleTy, NVPTX::LDV_i16_v2,
+ NVPTX::LDV_i32_v2, NVPTX::LDV_i64_v2);
break;
case NVPTXISD::LoadV4:
- Opcode =
- pickOpcodeForVT(EltVT.SimpleTy, NVPTX::LDV_i8_v4, NVPTX::LDV_i16_v4,
- NVPTX::LDV_i32_v4, NVPTX::LDV_i64_v4);
+ Opcode = pickOpcodeForVT(EltVT.SimpleTy, NVPTX::LDV_i16_v4,
+ NVPTX::LDV_i32_v4, NVPTX::LDV_i64_v4);
break;
case NVPTXISD::LoadV8:
- Opcode = pickOpcodeForVT(EltVT.SimpleTy, {/* no v8i8 */}, {/* no v8i16 */},
+ Opcode = pickOpcodeForVT(EltVT.SimpleTy, {/* no v8i16 */},
NVPTX::LDV_i32_v8, {/* no v8i64 */});
break;
}
@@ -1230,22 +1224,21 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
default:
llvm_unreachable("Unexpected opcode");
case ISD::LOAD:
- Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_i8,
- NVPTX::LD_GLOBAL_NC_i16, NVPTX::LD_GLOBAL_NC_i32,
- NVPTX::LD_GLOBAL_NC_i64);
+ Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_i16,
+ NVPTX::LD_GLOBAL_NC_i32, NVPTX::LD_GLOBAL_NC_i64);
break;
case NVPTXISD::LoadV2:
- Opcode = pickOpcodeForVT(
- TargetVT, NVPTX::LD_GLOBAL_NC_v2i8, NVPTX::LD_GLOBAL_NC_v2i16,
- NVPTX::LD_GLOBAL_NC_v2i32, NVPTX::LD_GLOBAL_NC_v2i64);
+ Opcode =
+ pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_v2i16,
+ NVPTX::LD_GLOBAL_NC_v2i32, NVPTX::LD_GLOBAL_NC_v2i64);
break;
case NVPTXISD::LoadV4:
- Opcode = pickOpcodeForVT(
- TargetVT, NVPTX::LD_GLOBAL_NC_v4i8, NVPTX::LD_GLOBAL_NC_v4i16,
- NVPTX::LD_GLOBAL_NC_v4i32, NVPTX::LD_GLOBAL_NC_v4i64);
+ Opcode =
+ pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_v4i16,
+ NVPTX::LD_GLOBAL_NC_v4i32, NVPTX::LD_GLOBAL_NC_v4i64);
break;
case NVPTXISD::LoadV8:
- Opcode = pickOpcodeForVT(TargetVT, {/* no v8i8 */}, {/* no v8i16 */},
+ Opcode = pickOpcodeForVT(TargetVT, {/* no v8i16 */},
NVPTX::LD_GLOBAL_NC_v8i32, {/* no v8i64 */});
break;
}
@@ -1276,8 +1269,9 @@ bool NVPTXDAGToDAGISel::tryLDU(SDNode *N) {
break;
}
- const MVT::SimpleValueType SelectVT =
- MVT::getIntegerVT(LD->getMemoryVT().getSizeInBits() / NumElts).SimpleTy;
+ SDLoc DL(N);
+ const unsigned FromTypeWidth = LD->getMemoryVT().getSizeInBits() / NumElts;
+ const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy;
// If this is an LDU intrinsic, the address is the third operand. If its an
// LDU SD node (from custom vector handling), then its the second operand
@@ -1286,32 +1280,28 @@ bool NVPTXDAGToDAGISel::tryLDU(SDNode *N) {
SDValue Base, Offset;
SelectADDR(Addr, Base, Offset);
- SDValue Ops[] = {Base, Offset, LD->getChain()};
+ SDValue Ops[] = {getI32Imm(FromTypeWidth, DL), Base, Offset, LD->getChain()};
std::optional<unsigned> Opcode;
switch (N->getOpcode()) {
default:
llvm_unreachable("Unexpected opcode");
case ISD::INTRINSIC_W_CHAIN:
- Opcode =
- pickOpcodeForVT(SelectVT, NVPTX::LDU_GLOBAL_i8, NVPTX::LDU_GLOBAL_i16,
- NVPTX::LDU_GLOBAL_i32, NVPTX::LDU_GLOBAL_i64);
+ Opcode = pickOpcodeForVT(TargetVT, NVPTX::LDU_GLOBAL_i16,
+ NVPTX::LDU_GLOBAL_i32, NVPTX::LDU_GLOBAL_i64);
break;
case NVPTXISD::LDUV2:
- Opcode = pickOpcodeForVT(SelectVT, NVPTX::LDU_GLOBAL_v2i8,
- NVPTX::LDU_GLOBAL_v2i16, NVPTX::LDU_GLOBAL_v2i32,
- NVPTX::LDU_GLOBAL_v2i64);
+ Opcode = pickOpcodeForVT(TargetVT, NVPTX::LDU_GLOBAL_v2i16,
+ NVPTX::LDU_GLOBAL_v2i32, NVPTX::LDU_GLOBAL_v2i64);
break;
case NVPTXISD::LDUV4:
- Opcode = pickOpcodeForVT(SelectVT, NVPTX::LDU_GLOBAL_v4i8,
- NVPTX::LDU_GLOBAL_v4i16, NVPTX::LDU_GLOBAL_v4i32,
- {/* no v4i64 */});
+ Opcode = pickOpcodeForVT(TargetVT, NVPTX::LDU_GLOBAL_v4i16,
+ NVPTX::LDU_GLOBAL_v4i32, {/* no v4i64 */});
break;
}
if (!Opcode)
return false;
- SDLoc DL(N);
SDNode *NVPTXLDU = CurDAG->getMachineNode(*Opcode, DL, LD->getVTList(), Ops);
ReplaceNode(LD, NVPTXLDU);
@@ -1362,8 +1352,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
Chain};
const std::optional<unsigned> Opcode =
- pickOpcodeForVT(Value.getSimpleValueType().SimpleTy, NVPTX::ST_i8,
- NVPTX::ST_i16, NVPTX::ST_i32, NVPTX::ST_i64);
+ pickOpcodeForVT(Value.getSimpleValueType().SimpleTy, NVPTX::ST_i16,
+ NVPTX::ST_i32, NVPTX::ST_i64);
if (!Opcode)
return false;
@@ -1423,16 +1413,16 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
default:
return false;
case NVPTXISD::StoreV2:
- Opcode = pickOpcodeForVT(EltVT, NVPTX::STV_i8_v2, NVPTX::STV_i16_v2,
- NVPTX::STV_i32_v2, NVPTX::STV_i64_v2);
+ Opcode = pickOpcodeForVT(EltVT, NVPTX::STV_i16_v2, NVPTX::STV_i32_v2,
+ NVPTX::STV_i64_v2);
break;
case NVPTXISD::StoreV4:
- Opcode = pickOpcodeForVT(EltVT, NVPTX::STV_i8_v4, NVPTX::STV_i16_v4,
- NVPTX::STV_i32_v4, NVPTX::STV_i64_v4);
+ Opcode = pickOpcodeForVT(EltVT, NVPTX::STV_i16_v4, NVPTX::STV_i32_v4,
+ NVPTX::STV_i64_v4);
break;
case NVPTXISD::StoreV8:
- Opcode = pickOpcodeForVT(EltVT, {/* no v8i8 */}, {/* no v8i16 */},
- NVPTX::STV_i32_v8, {/* no v8i64 */});
+ Opcode = pickOpcodeForVT(EltVT, {/* no v8i16 */}, NVPTX::STV_i32_v8,
+ {/* no v8i64 */});
break;
}
@@ -1687,10 +1677,11 @@ bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) {
auto API = APF.bitcastToAPInt();
API = API.concat(API);
auto Const = CurDAG->getTargetConstant(API, DL, MVT::i32);
- return SDValue(CurDAG->getMachineNode(NVPTX::IMOV32i, DL, VT, Const), 0);
+ return SDValue(CurDAG->getMachineNode(NVPTX::MOV_B32_i, DL, VT, Const),
+ 0);
}
auto Const = CurDAG->getTargetConstantFP(APF, DL, VT);
- return SDValue(CurDAG->getMachineNode(NVPTX::BFMOV16i, DL, VT, Const), 0);
+ return SDValue(CurDAG->getMachineNode(NVPTX::MOV_BF16_i, DL, VT, Const), 0);
};
switch (N->getOpcode()) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrFormats.td b/llvm/lib/Target/NVPTX/NVPTXInstrFormats.td
index 86dcb4a9384f1..719be0300940e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrFormats.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrFormats.td
@@ -11,15 +11,9 @@
//
//===----------------------------------------------------------------------===//
-// Vector instruction type enum
-class VecInstTypeEnum<bits<4> val> {
- bits<4> Value=val;
-}
-def VecNOP : VecInstTypeEnum<0>;
-
// Generic NVPTX Format
-class NVPTXInst<dag outs, dag ins, string asmstr, list<dag> pattern>
+class NVPTXInst<dag outs, dag ins, string asmstr, list<dag> pattern = []>
: Instruction {
field bits<14> Inst;
@@ -30,7 +24,6 @@ class NVPTXInst<dag outs, dag ins, string asmstr, list<dag> pattern>
let Pattern = pattern;
// TSFlagFields
- bits<4> VecInstType = VecNOP.Value;
bit IsLoad = false;
bit IsStore = false;
@@ -45,7 +38,6 @@ class NVPTXInst<dag outs, dag ins, string asmstr, list<dag> pattern>
// 2**(2-1) = 2.
bits<2> IsSuld = 0;
- let TSFlags{3...0} = VecInstType;
let TSFlags{4} = IsLoad;
let TSFlags{5} = IsStore;
let TSFlags{6} = IsTex;
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
index e218ef17bb09b..34fe467c94563 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
@@ -35,23 +35,23 @@ void NVPTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
const TargetRegisterClass *DestRC = MRI.getRegClass(DestReg);
const TargetRegisterClass *SrcRC = MRI.getRegClass(SrcReg);
- if (RegInfo.getRegSizeInBits(*DestRC) != RegInfo.getRegSizeInBits(*SrcRC))
+ if (DestRC != SrcRC)
report_fatal_error("Copy one register into another with a different width");
unsigned Op;
- if (DestRC == &NVPTX::B1RegClass) {
- Op = NVPTX::IMOV1r;
- } else if (DestRC == &NVPTX::B16RegClass) {
- Op = NVPTX::MOV16r;
- } else if (DestRC == &NVPTX::B32RegClass) {
- Op = NVPTX::IMOV32r;
- } else if (DestRC == &NVPTX::B64RegClass) {
- Op = NVPTX::IMOV64r;
- } else if (DestRC == &NVPTX::B128RegClass) {
- Op = NVPTX::IMOV128r;
- } else {
+ if (DestRC == &NVPTX::B1RegClass)
+ Op = NVPTX::MOV_B1_r;
+ else if (DestRC == &NVPTX::B16RegClass)
+ Op = NVPTX::MOV_B16_r;
+ else if (DestRC == &NVPTX::B32RegClass)
+ Op = NVPTX::MOV_B32_r;
+ else if (DestRC == &NVPTX::B64RegClass)
+ Op = NVPTX::MOV_B64_r;
+ else if (DestRC == &NVPTX::B128RegClass)
+ Op = NVPTX::MOV_B128_r;
+ else
llvm_unreachable("Bad register copy");
- }
+
BuildMI(MBB, I, DL, get(Op), DestReg)
.addReg(SrcReg, getKillRegState(KillSrc));
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 6000b40694763..d8047d31ff6f0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -15,19 +15,8 @@ include "NVPTXInstrFormats.td"
let OperandType = "OPERAND_IMMEDIATE" in {
def f16imm : Operand<f16>;
def bf16imm : Operand<bf16>;
-
}
-// List of vector specific properties
-def isVecLD : VecInstTypeEnum<1>;
-def isVecST : VecInstTypeEnum<2>;
-def isVecBuild : VecInstTypeEnum<3>;
-def isVecShuffle : VecInstTypeEnum<4>;
-def isVecExtract : VecInstTypeEnum<5>;
-def isVecInsert : VecInstTypeEnum<6>;
-def isVecDest : VecInstTypeEnum<7>;
-def isVecOther : VecInstTypeEnum<15>;
-
//===----------------------------------------------------------------------===//
// NVPTX Operand Definitions.
//===----------------------------------------------------------------------===//
@@ -484,46 +473,28 @@ let hasSideEffects = false in {
// takes a CvtMode immediate that defines the conversion mode to use. It can
// be CvtNONE to omit a conversion mode.
multiclass CVT_FROM_ALL<string ToType, RegisterClass RC, list<Predicate> Preds = []> {
- def _s8 :
- BasicFlagsNVPTXInst<(outs RC:$dst),
- (ins B16:$src), (ins CvtMode:$mode),
- "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s8">,
- Requires<Preds>;
- def _u8 :
- BasicFlagsNVPTXInst<(outs RC:$dst),
- (ins B16:$src), (ins CvtMode:$mode),
- "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u8">,
- Requires<Preds>;
- def _s16 :
- BasicFlagsNVPTXInst<(outs RC:$dst),
- (ins B16:$src), (ins CvtMode:$mode),
- "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s16">,
- Requires<Preds>;
- def _u16 :
- BasicFlagsNVPTXInst<(outs RC:$dst),
- (ins B16:$src), (ins CvtMode:$mode),
- "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u16">,
- Requires<Preds>;
- def _s32 :
- BasicFlagsNVPTXInst<(outs RC:$dst),
- (ins B32:$src), (ins CvtMode:$mode),
- "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s32">,
- Requires<Preds>;
- def _u32 :
- BasicFlagsNVPTXInst<(outs RC:$dst),
- (ins B32:$src), (ins CvtMode:$mode),
- "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u32">,
- Requires<Preds>;
- def _s64 :
- BasicFlagsNVPTXInst<(outs RC:$dst),
- (ins B64:$src), (ins CvtMode:$mode),
- "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s64">,
- Requires<Preds>;
- def _u64 :
- BasicFlagsNVPTXInst<(outs RC:$dst),
- (ins B64:$src), (ins CvtMode:$mode),
- "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u64">,
- Requires<Preds>;
+ foreach sign = ["s", "u"] in {
+ def _ # sign # "8" :
+ BasicFlagsNVPTXInst<(outs RC:$dst),
+ (ins B16:$src), (ins CvtMode:$mode),
+ "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "8">,
+ Requires<Preds>;
+ def _ # sign # "16" :
+ BasicFlagsNVPTXInst<(outs RC:$dst),
+ (ins B16:$src), (ins CvtMode:$mode),
+ "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "16">,
+ Requires<Preds>;
+ def _ # sign # "32" :
+ BasicFlagsNVPTXInst<(outs RC:$dst),
+ (ins B32:$src), (ins CvtMode:$mode),
+ "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "32">,
+ Requires<Preds>;
+ def _ # sign # "64" :
+ BasicFlagsNVPTXInst<(outs RC:$dst),
+ (ins B64:$src), (ins CvtMode:$mode),
+ "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "64">,
+ Requires<Preds>;
+ }
def _f16 :
BasicFlagsNVPTXInst<(outs RC:$dst),
(ins B16:$src), (ins CvtMode:$mode),
@@ -554,14 +525,12 @@ let hasSideEffects = false in {
}
// Generate cvts from all types to all types.
- defm CVT_s8 : CVT_FROM_ALL<"s8", B16>;
- defm CVT_u8 : CVT_FROM_ALL<"u8", B16>;
- defm CVT_s16 : CVT_FROM_ALL<"s16", B16>;
- defm CVT_u16 : CVT_FROM_ALL<"u16", B16>;
- defm CVT_s32 : CVT_FROM_ALL<"s32", B32>;
- defm CVT_u32 : CVT_FROM_ALL<"u32", B32>;
- defm CVT_s64 : CVT_FROM_ALL<"s64", B64>;
- defm CVT_u64 : CVT_FROM_ALL<"u64", B64>;
+ foreach sign = ["s", "u"] in {
+ defm CVT_ # sign # "8" : CVT_FROM_ALL<sign # "8", B16>;
+ defm CVT_ # sign # "16" : CVT_FROM_ALL<sign # "16", B16>;
+ defm CVT_ # sign # "32" : CVT_FROM_ALL<sign # "32", B32>;
+ defm CVT_ # sign # "64" : CVT_FROM_ALL<sign # "64", B64>;
+ }
defm CVT_f16 : CVT_FROM_ALL<"f16", B16>;
defm CVT_bf16 : CVT_FROM_ALL<"bf16", B16, [hasPTX<78>, hasSM<90>]>;
defm CVT_f32 : CVT_FROM_ALL<"f32", B32>;
@@ -569,18 +538,12 @@ let hasSideEffects = false in {
// These cvts are different from those above: The source and dest registers
// are of the same type.
- def CVT_INREG_s16_s8 : BasicNVPTXInst<(outs B16:$dst), (ins B16:$src),
- "cvt.s16.s8">;
- def CVT_INREG_s32_s8 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src),
- "cvt.s32.s8">;
- def CVT_INREG_s32_s16 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src),
- "cvt.s32.s16">;
- def CVT_INREG_s64_s8 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src),
- "cvt.s64.s8">;
- def CVT_INREG_s64_s16 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src),
- "cvt.s64.s16">;
- def CVT_INREG_s64_s32 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src),
- "cvt.s64.s32">;
+ def CVT_INREG_s16_s8 : BasicNVPTXInst<(outs B16:$dst), (ins B16:$src), "cvt.s16.s8">;
+ def CVT_INREG_s32_s8 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), "cvt.s32.s8">;
+ def CVT_INREG_s32_s16 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), "cvt.s32.s16">;
+ def CVT_INREG_s64_s8 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), "cvt.s64.s8">;
+ def CVT_INREG_s64_s16 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), "cvt.s64.s16">;
+ def CVT_INREG_s64_s32 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), "cvt.s64.s32">;
multiclass CVT_FROM_FLOAT_V2_SM80<string FromName, RegisterClass RC> {
def _f32 :
@@ -782,7 +745,7 @@ defm SUB : I3<"sub.s", sub, commutative = false>;
def ADD16x2 : I16x2<"add.s", add>;
-// in32 and int64 addition and subtraction with carry-out.
+// int32 and int64 addition and subtraction with carry-out.
defm ADDCC : ADD_SUB_INT_CARRY<"add.cc", addc, commutative = true>;
defm SUBCC : ADD_SUB_INT_CARRY<"sub.cc", subc, commutative = false>;
@@ -803,17 +766,17 @@ defm UDIV : I3<"div.u", udiv, commutative = false>;
defm SREM : I3<"rem.s", srem, commutative = false>;
defm UREM : I3<"rem.u", urem, commutative = false>;
-// Integer absolute value. NumBits should be one minus the bit width of RC.
-// This idiom implements the algorithm at
-// http://graphics.stanford.edu/~seander/bithacks.html#IntegerAbs.
-multiclass ABS<ValueType T, RegisterClass RC, string SizeName> {
- def : BasicNVPTXInst<(outs RC:$dst), (ins RC:$a),
- "abs" # SizeName,
- [(set T:$dst, (abs T:$a))]>;
+foreach t = [I16RT, I32RT, I64RT] in {
+ def ABS_S # t.Size :
+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a),
+ "abs.s" # t.Size,
+ [(set t.Ty:$dst, (abs t.Ty:$a))]>;
+
+ def NEG_S # t.Size :
+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src),
+ "neg.s" # t.Size,
+ [(set t.Ty:$dst, (ineg t.Ty:$src))]>;
}
-defm ABS_16 : ABS<i16, B16, ".s16">;
-defm ABS_32 : ABS<i32, B32, ".s32">;
-defm ABS_64 : ABS<i64, B64, ".s64">;
// Integer min/max.
defm SMAX : I3<"max.s", smax, commutative = true>;
@@ -830,116 +793,63 @@ def UMIN16x2 : I16x2<"min.u", umin>;
//
// Wide multiplication
//
-def MULWIDES64 :
- BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, B32:$b), "mul.wide.s32">;
-def MULWIDES64Imm :
- BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i32imm:$b), "mul.wide.s32">;
-
-def MULWIDEU64 :
- BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, B32:$b), "mul.wide.u32">;
-def MULWIDEU64Imm :
- BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i32imm:$b), "mul.wide.u32">;
-
-def MULWIDES32 :
- BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b), "mul.wide.s16">;
-def MULWIDES32Imm :
- BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i16imm:$b), "mul.wide.s16">;
-
-def MULWIDEU32 :
- BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b), "mul.wide.u16">;
-def MULWIDEU32Imm :
- BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i16imm:$b), "mul.wide.u16">;
def SDTMulWide : SDTypeProfile<1, 2, [SDTCisInt<0>, SDTCisInt<1>, SDTCisSameAs<1, 2>]>;
-def mul_wide_signed : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide, [SDNPCommutative]>;
-def mul_wide_unsigned : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide, [SDNPCommutative]>;
+def smul_wide : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide, [SDNPCommutative]>;
+def umul_wide : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide, [SDNPCommutative]>;
-// Matchers for signed, unsigned mul.wide ISD nodes.
-let Predicates = [hasOptEnabled] in {
- def : Pat<(i32 (mul_wide_signed i16:$a, i16:$b)), (MULWIDES32 $a, $b)>;
- def : Pat<(i32 (mul_wide_signed i16:$a, imm:$b)), (MULWIDES32Imm $a, imm:$b)>;
- def : Pat<(i32 (mul_wide_unsigned i16:$a, i16:$b)), (MULWIDEU32 $a, $b)>;
- def : Pat<(i32 (mul_wide_unsigned i16:$a, imm:$b)), (MULWIDEU32Imm $a, imm:$b)>;
- def : Pat<(i64 (mul_wide_signed i32:$a, i32:$b)), (MULWIDES64 $a, $b)>;
- def : Pat<(i64 (mul_wide_signed i32:$a, imm:$b)), (MULWIDES64Imm $a, imm:$b)>;
- def : Pat<(i64 (mul_wide_unsigned i32:$a, i32:$b)), (MULWIDEU64 $a, $b)>;
- def : Pat<(i64 (mul_wide_unsigned i32:$a, imm:$b)), (MULWIDEU64Imm $a, imm:$b)>;
+multiclass MULWIDEInst<string suffix, SDPatternOperator op, RegTyInfo big_t, RegTyInfo small_t> {
+ def suffix # _rr :
+ BasicNVPTXInst<(outs big_t.RC:$dst), (ins small_t.RC:$a, small_t.RC:$b),
+ "mul.wide." # suffix,
+ [(set big_t.Ty:$dst, (op small_t.Ty:$a, small_t.Ty:$b))]>;
+ def suffix # _ri :
+ BasicNVPTXInst<(outs big_t.RC:$dst), (ins small_t.RC:$a, small_t.Imm:$b),
+ "mul.wide." # suffix,
+ [(set big_t.Ty:$dst, (op small_t.Ty:$a, imm:$b))]>;
}
+defm MUL_WIDE : MULWIDEInst<"s32", smul_wide, I64RT, I32RT>;
+defm MUL_WIDE : MULWIDEInst<"u32", umul_wide, I64RT, I32RT>;
+defm MUL_WIDE : MULWIDEInst<"s16", smul_wide, I32RT, I16RT>;
+defm MUL_WIDE : MULWIDEInst<"u16", umul_wide, I32RT, I16RT>;
+
//
// Integer multiply-add
//
-def mul_oneuse : OneUse2<mul>;
-
-multiclass MAD<string Ptx, ValueType VT, NVPTXRegClass Reg, Operand Imm> {
- def rrr:
- BasicNVPTXInst<(outs Reg:$dst),
- (ins Reg:$a, Reg:$b, Reg:$c),
- Ptx,
- [(set VT:$dst, (add (mul_oneuse VT:$a, VT:$b), VT:$c))]>;
-
- def rir:
- BasicNVPTXInst<(outs Reg:$dst),
- (ins Reg:$a, Imm:$b, Reg:$c),
- Ptx,
- [(set VT:$dst, (add (mul_oneuse VT:$a, imm:$b), VT:$c))]>;
- def rri:
- BasicNVPTXInst<(outs Reg:$dst),
- (ins Reg:$a, Reg:$b, Imm:$c),
- Ptx,
- [(set VT:$dst, (add (mul_oneuse VT:$a, VT:$b), imm:$c))]>;
- def rii:
- BasicNVPTXInst<(outs Reg:$dst),
- (ins Reg:$a, Imm:$b, Imm:$c),
- Ptx,
- [(set VT:$dst, (add (mul_oneuse VT:$a, imm:$b), imm:$c))]>;
-}
-
-let Predicates = [hasOptEnabled] in {
-defm MAD16 : MAD<"mad.lo.s16", i16, B16, i16imm>;
-defm MAD32 : MAD<"mad.lo.s32", i32, B32, i32imm>;
-defm MAD64 : MAD<"mad.lo.s64", i64, B64, i64imm>;
-}
-
-multiclass MAD_WIDE<string PtxSuffix, OneUse2 Op, RegTyInfo BigT, RegTyInfo SmallT> {
+multiclass MADInst<string suffix, SDPatternOperator op, RegTyInfo big_t, RegTyInfo small_t> {
def rrr:
- BasicNVPTXInst<(outs BigT.RC:$dst),
- (ins SmallT.RC:$a, SmallT.RC:$b, BigT.RC:$c),
- "mad.wide." # PtxSuffix,
- [(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, SmallT.Ty:$b), BigT.Ty:$c))]>;
+ BasicNVPTXInst<(outs big_t.RC:$dst),
+ (ins small_t.RC:$a, small_t.RC:$b, big_t.RC:$c),
+ "mad." # suffix,
+ [(set big_t.Ty:$dst, (add (OneUse2<op> small_t.Ty:$a, small_t.Ty:$b), big_t.Ty:$c))]>;
def rri:
- BasicNVPTXInst<(outs BigT.RC:$dst),
- (ins SmallT.RC:$a, SmallT.RC:$b, BigT.Imm:$c),
- "mad.wide." # PtxSuffix,
- [(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, SmallT.Ty:$b), imm:$c))]>;
+ BasicNVPTXInst<(outs big_t.RC:$dst),
+ (ins small_t.RC:$a, small_t.RC:$b, big_t.Imm:$c),
+ "mad." # suffix,
+ [(set big_t.Ty:$dst, (add (OneUse2<op> small_t.Ty:$a, small_t.Ty:$b), imm:$c))]>;
def rir:
- BasicNVPTXInst<(outs BigT.RC:$dst),
- (ins SmallT.RC:$a, SmallT.Imm:$b, BigT.RC:$c),
- "mad.wide." # PtxSuffix,
- [(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, imm:$b), BigT.Ty:$c))]>;
+ BasicNVPTXInst<(outs big_t.RC:$dst),
+ (ins small_t.RC:$a, small_t.Imm:$b, big_t.RC:$c),
+ "mad." # suffix,
+ [(set big_t.Ty:$dst, (add (OneUse2<op> small_t.Ty:$a, imm:$b), big_t.Ty:$c))]>;
def rii:
- BasicNVPTXInst<(outs BigT.RC:$dst),
- (ins SmallT.RC:$a, SmallT.Imm:$b, BigT.Imm:$c),
- "mad.wide." # PtxSuffix,
- [(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, imm:$b), imm:$c))]>;
+ BasicNVPTXInst<(outs big_t.RC:$dst),
+ (ins small_t.RC:$a, small_t.Imm:$b, big_t.Imm:$c),
+ "mad." # suffix,
+ [(set big_t.Ty:$dst, (add (OneUse2<op> small_t.Ty:$a, imm:$b), imm:$c))]>;
}
-def mul_wide_unsigned_oneuse : OneUse2<mul_wide_unsigned>;
-def mul_wide_signed_oneuse : OneUse2<mul_wide_signed>;
-
let Predicates = [hasOptEnabled] in {
-defm MAD_WIDE_U16 : MAD_WIDE<"u16", mul_wide_unsigned_oneuse, I32RT, I16RT>;
-defm MAD_WIDE_S16 : MAD_WIDE<"s16", mul_wide_signed_oneuse, I32RT, I16RT>;
-defm MAD_WIDE_U32 : MAD_WIDE<"u32", mul_wide_unsigned_oneuse, I64RT, I32RT>;
-defm MAD_WIDE_S32 : MAD_WIDE<"s32", mul_wide_signed_oneuse, I64RT, I32RT>;
-}
+ defm MAD_LO_S16 : MADInst<"lo.s16", mul, I16RT, I16RT>;
+ defm MAD_LO_S32 : MADInst<"lo.s32", mul, I32RT, I32RT>;
+ defm MAD_LO_S64 : MADInst<"lo.s64", mul, I64RT, I64RT>;
-foreach t = [I16RT, I32RT, I64RT] in {
- def NEG_S # t.Size :
- BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src),
- "neg.s" # t.Size,
- [(set t.Ty:$dst, (ineg t.Ty:$src))]>;
+ defm MAD_WIDE_U16 : MADInst<"wide.u16", umul_wide, I32RT, I16RT>;
+ defm MAD_WIDE_S16 : MADInst<"wide.s16", smul_wide, I32RT, I16RT>;
+ defm MAD_WIDE_U32 : MADInst<"wide.u32", umul_wide, I64RT, I32RT>;
+ defm MAD_WIDE_S32 : MADInst<"wide.s32", smul_wide, I64RT, I32RT>;
}
//-----------------------------------
@@ -1050,8 +960,7 @@ def fdiv_approx : PatFrag<(ops node:$a, node:$b),
def FRCP32_approx_r :
BasicFlagsNVPTXInst<(outs B32:$dst),
- (ins B32:$b),
- (ins FTZFlag:$ftz),
+ (ins B32:$b), (ins FTZFlag:$ftz),
"rcp.approx$ftz.f32",
[(set f32:$dst, (fdiv_approx f32imm_1, f32:$b))]>;
@@ -1060,14 +969,12 @@ def FRCP32_approx_r :
//
def FDIV32_approx_rr :
BasicFlagsNVPTXInst<(outs B32:$dst),
- (ins B32:$a, B32:$b),
- (ins FTZFlag:$ftz),
+ (ins B32:$a, B32:$b), (ins FTZFlag:$ftz),
"div.approx$ftz.f32",
[(set f32:$dst, (fdiv_approx f32:$a, f32:$b))]>;
def FDIV32_approx_ri :
BasicFlagsNVPTXInst<(outs B32:$dst),
- (ins B32:$a, f32imm:$b),
- (ins FTZFlag:$ftz),
+ (ins B32:$a, f32imm:$b), (ins FTZFlag:$ftz),
"div.approx$ftz.f32",
[(set f32:$dst, (fdiv_approx f32:$a, fpimm:$b))]>;
//
@@ -1090,14 +997,12 @@ def : Pat<(fdiv_full f32imm_1, f32:$b),
//
def FDIV32rr :
BasicFlagsNVPTXInst<(outs B32:$dst),
- (ins B32:$a, B32:$b),
- (ins FTZFlag:$ftz),
+ (ins B32:$a, B32:$b), (ins FTZFlag:$ftz),
"div.full$ftz.f32",
[(set f32:$dst, (fdiv_full f32:$a, f32:$b))]>;
def FDIV32ri :
BasicFlagsNVPTXInst<(outs B32:$dst),
- (ins B32:$a, f32imm:$b),
- (ins FTZFlag:$ftz),
+ (ins B32:$a, f32imm:$b), (ins FTZFlag:$ftz),
"div.full$ftz.f32",
[(set f32:$dst, (fdiv_full f32:$a, fpimm:$b))]>;
//
@@ -1111,8 +1016,7 @@ def fdiv_ftz : PatFrag<(ops node:$a, node:$b),
def FRCP32r_prec :
BasicFlagsNVPTXInst<(outs B32:$dst),
- (ins B32:$b),
- (ins FTZFlag:$ftz),
+ (ins B32:$b), (ins FTZFlag:$ftz),
"rcp.rn$ftz.f32",
[(set f32:$dst, (fdiv_ftz f32imm_1, f32:$b))]>;
//
@@ -1120,14 +1024,12 @@ def FRCP32r_prec :
//
def FDIV32rr_prec :
BasicFlagsNVPTXInst<(outs B32:$dst),
- (ins B32:$a, B32:$b),
- (ins FTZFlag:$ftz),
+ (ins B32:$a, B32:$b), (ins FTZFlag:$ftz),
"div.rn$ftz.f32",
[(set f32:$dst, (fdiv_ftz f32:$a, f32:$b))]>;
def FDIV32ri_prec :
BasicFlagsNVPTXInst<(outs B32:$dst),
- (ins B32:$a, f32imm:$b),
- (ins FTZFlag:$ftz),
+ (ins B32:$a, f32imm:$b), (ins FTZFlag:$ftz),
"div.rn$ftz.f32",
[(set f32:$dst, (fdiv_ftz f32:$a, fpimm:$b))]>;
@@ -1206,10 +1108,8 @@ def TANH_APPROX_f32 :
// Template for three-arg bitwise operations. Takes three args, Creates .b16,
// .b32, .b64, and .pred (predicate registers -- i.e., i1) versions of OpcStr.
multiclass BITWISE<string OpcStr, SDNode OpNode> {
- defm b1 : I3Inst<OpcStr # ".pred", OpNode, I1RT, commutative = true>;
- defm b16 : I3Inst<OpcStr # ".b16", OpNode, I16RT, commutative = true>;
- defm b32 : I3Inst<OpcStr # ".b32", OpNode, I32RT, commutative = true>;
- defm b64 : I3Inst<OpcStr # ".b64", OpNode, I64RT, commutative = true>;
+ foreach t = [I1RT, I16RT, I32RT, I64RT] in
+ defm _ # t.PtxType : I3Inst<OpcStr # "." # t.PtxType, OpNode, t, commutative = true>;
}
defm OR : BITWISE<"or", or>;
@@ -1217,48 +1117,40 @@ defm AND : BITWISE<"and", and>;
defm XOR : BITWISE<"xor", xor>;
// PTX does not support mul on predicates, convert to and instructions
-def : Pat<(mul i1:$a, i1:$b), (ANDb1rr $a, $b)>;
-def : Pat<(mul i1:$a, imm:$b), (ANDb1ri $a, imm:$b)>;
+def : Pat<(mul i1:$a, i1:$b), (AND_predrr $a, $b)>;
+def : Pat<(mul i1:$a, imm:$b), (AND_predri $a, imm:$b)>;
foreach op = [add, sub] in {
- def : Pat<(op i1:$a, i1:$b), (XORb1rr $a, $b)>;
- def : Pat<(op i1:$a, imm:$b), (XORb1ri $a, imm:$b)>;
+ def : Pat<(op i1:$a, i1:$b), (XOR_predrr $a, $b)>;
+ def : Pat<(op i1:$a, imm:$b), (XOR_predri $a, imm:$b)>;
}
// These transformations were once reliably performed by instcombine, but thanks
// to poison semantics they are no longer safe for LLVM IR, perform them here
// instead.
-def : Pat<(select i1:$a, i1:$b, 0), (ANDb1rr $a, $b)>;
-def : Pat<(select i1:$a, 1, i1:$b), (ORb1rr $a, $b)>;
+def : Pat<(select i1:$a, i1:$b, 0), (AND_predrr $a, $b)>;
+def : Pat<(select i1:$a, 1, i1:$b), (OR_predrr $a, $b)>;
// Lower logical v2i16/v4i8 ops as bitwise ops on b32.
foreach vt = [v2i16, v4i8] in {
- def : Pat<(or vt:$a, vt:$b), (ORb32rr $a, $b)>;
- def : Pat<(xor vt:$a, vt:$b), (XORb32rr $a, $b)>;
- def : Pat<(and vt:$a, vt:$b), (ANDb32rr $a, $b)>;
+ def : Pat<(or vt:$a, vt:$b), (OR_b32rr $a, $b)>;
+ def : Pat<(xor vt:$a, vt:$b), (XOR_b32rr $a, $b)>;
+ def : Pat<(and vt:$a, vt:$b), (AND_b32rr $a, $b)>;
// The constants get legalized into a bitcast from i32, so that's what we need
// to match here.
def: Pat<(or vt:$a, (vt (bitconvert (i32 imm:$b)))),
- (ORb32ri $a, imm:$b)>;
+ (OR_b32ri $a, imm:$b)>;
def: Pat<(xor vt:$a, (vt (bitconvert (i32 imm:$b)))),
- (XORb32ri $a, imm:$b)>;
+ (XOR_b32ri $a, imm:$b)>;
def: Pat<(and vt:$a, (vt (bitconvert (i32 imm:$b)))),
- (ANDb32ri $a, imm:$b)>;
-}
-
-def NOT1 : BasicNVPTXInst<(outs B1:$dst), (ins B1:$src),
- "not.pred",
- [(set i1:$dst, (not i1:$src))]>;
-def NOT16 : BasicNVPTXInst<(outs B16:$dst), (ins B16:$src),
- "not.b16",
- [(set i16:$dst, (not i16:$src))]>;
-def NOT32 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src),
- "not.b32",
- [(set i32:$dst, (not i32:$src))]>;
-def NOT64 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src),
- "not.b64",
- [(set i64:$dst, (not i64:$src))]>;
+ (AND_b32ri $a, imm:$b)>;
+}
+
+foreach t = [I1RT, I16RT, I32RT, I64RT] in
+ def NOT_ # t.PtxType : BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src),
+ "not." # t.PtxType,
+ [(set t.Ty:$dst, (not t.Ty:$src))]>;
// Template for left/right shifts. Takes three operands,
// [dest (reg), src (reg), shift (reg or imm)].
@@ -1266,34 +1158,22 @@ def NOT64 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src),
//
// This template also defines a 32-bit shift (imm, imm) instruction.
multiclass SHIFT<string OpcStr, SDNode OpNode> {
- def i64rr :
- BasicNVPTXInst<(outs B64:$dst), (ins B64:$a, B32:$b),
- OpcStr # "64",
- [(set i64:$dst, (OpNode i64:$a, i32:$b))]>;
- def i64ri :
- BasicNVPTXInst<(outs B64:$dst), (ins B64:$a, i32imm:$b),
- OpcStr # "64",
- [(set i64:$dst, (OpNode i64:$a, (i32 imm:$b)))]>;
- def i32rr :
- BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
- OpcStr # "32",
- [(set i32:$dst, (OpNode i32:$a, i32:$b))]>;
- def i32ri :
- BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, i32imm:$b),
- OpcStr # "32",
- [(set i32:$dst, (OpNode i32:$a, (i32 imm:$b)))]>;
- def i32ii :
- BasicNVPTXInst<(outs B32:$dst), (ins i32imm:$a, i32imm:$b),
- OpcStr # "32",
- [(set i32:$dst, (OpNode (i32 imm:$a), (i32 imm:$b)))]>;
- def i16rr :
- BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, B32:$b),
- OpcStr # "16",
- [(set i16:$dst, (OpNode i16:$a, i32:$b))]>;
- def i16ri :
- BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, i32imm:$b),
- OpcStr # "16",
- [(set i16:$dst, (OpNode i16:$a, (i32 imm:$b)))]>;
+ let hasSideEffects = false in {
+ foreach t = [I64RT, I32RT, I16RT] in {
+ def t.Size # _rr :
+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, B32:$b),
+ OpcStr # t.Size,
+ [(set t.Ty:$dst, (OpNode t.Ty:$a, i32:$b))]>;
+ def t.Size # _ri :
+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, i32imm:$b),
+ OpcStr # t.Size,
+ [(set t.Ty:$dst, (OpNode t.Ty:$a, (i32 imm:$b)))]>;
+ def t.Size # _ii : // both shift operands are immediates
+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.Imm:$a, i32imm:$b),
+ OpcStr # t.Size,
+ [(set t.Ty:$dst, (OpNode (t.Ty imm:$a), (i32 imm:$b)))]>;
+ }
+ }
}
defm SHL : SHIFT<"shl.b", shl>;
@@ -1301,14 +1181,11 @@ defm SRA : SHIFT<"shr.s", sra>;
defm SRL : SHIFT<"shr.u", srl>;
// Bit-reverse
-def BREV32 :
- BasicNVPTXInst<(outs B32:$dst), (ins B32:$a),
- "brev.b32",
- [(set i32:$dst, (bitreverse i32:$a))]>;
-def BREV64 :
- BasicNVPTXInst<(outs B64:$dst), (ins B64:$a),
- "brev.b64",
- [(set i64:$dst, (bitreverse i64:$a))]>;
+foreach t = [I64RT, I32RT] in
+ def BREV_ # t.PtxType :
+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a),
+ "brev." # t.PtxType,
+ [(set t.Ty:$dst, (bitreverse t.Ty:$a))]>;
//
@@ -1562,10 +1439,7 @@ def SETP_bf16x2rr :
def addr : ComplexPattern<pAny, 2, "SelectADDR">;
-def ADDR_base : Operand<pAny> {
- let PrintMethod = "printOperand";
-}
-
+def ADDR_base : Operand<pAny>;
def ADDR : Operand<pAny> {
let PrintMethod = "printMemOperand";
let MIOperandInfo = (ops ADDR_base, i32imm);
@@ -1579,10 +1453,6 @@ def MmaCode : Operand<i32> {
let PrintMethod = "printMmaCode";
}
-def Offseti32imm : Operand<i32> {
- let PrintMethod = "printOffseti32imm";
-}
-
// Get pointer to local stack.
let hasSideEffects = false in {
def MOV_DEPOT_ADDR : NVPTXInst<(outs B32:$d), (ins i32imm:$num),
@@ -1594,33 +1464,31 @@ let hasSideEffects = false in {
// copyPhysreg is hard-coded in NVPTXInstrInfo.cpp
let hasSideEffects = false, isAsCheapAsAMove = true in {
- // Class for register-to-register moves
- class MOVr<RegisterClass RC, string OpStr> :
- BasicNVPTXInst<(outs RC:$dst), (ins RC:$src),
- "mov." # OpStr>;
-
- // Class for immediate-to-register moves
- class MOVi<RegisterClass RC, string OpStr, ValueType VT, Operand IMMType, SDNode ImmNode> :
- BasicNVPTXInst<(outs RC:$dst), (ins IMMType:$src),
- "mov." # OpStr,
- [(set VT:$dst, ImmNode:$src)]>;
-}
+ let isMoveReg = true in
+ class MOVr<RegisterClass RC, string OpStr> :
+ BasicNVPTXInst<(outs RC:$dst), (ins RC:$src), "mov." # OpStr>;
-def IMOV1r : MOVr<B1, "pred">;
-def MOV16r : MOVr<B16, "b16">;
-def IMOV32r : MOVr<B32, "b32">;
-def IMOV64r : MOVr<B64, "b64">;
-def IMOV128r : MOVr<B128, "b128">;
+ let isMoveImm = true in
+ class MOVi<RegTyInfo t, string suffix> :
+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.Imm:$src),
+ "mov." # suffix,
+ [(set t.Ty:$dst, t.ImmNode:$src)]>;
+}
+def MOV_B1_r : MOVr<B1, "pred">;
+def MOV_B16_r : MOVr<B16, "b16">;
+def MOV_B32_r : MOVr<B32, "b32">;
+def MOV_B64_r : MOVr<B64, "b64">;
+def MOV_B128_r : MOVr<B128, "b128">;
-def IMOV1i : MOVi<B1, "pred", i1, i1imm, imm>;
-def IMOV16i : MOVi<B16, "b16", i16, i16imm, imm>;
-def IMOV32i : MOVi<B32, "b32", i32, i32imm, imm>;
-def IMOV64i : MOVi<B64, "b64", i64, i64imm, imm>;
-def FMOV16i : MOVi<B16, "b16", f16, f16imm, fpimm>;
-def BFMOV16i : MOVi<B16, "b16", bf16, bf16imm, fpimm>;
-def FMOV32i : MOVi<B32, "b32", f32, f32imm, fpimm>;
-def FMOV64i : MOVi<B64, "b64", f64, f64imm, fpimm>;
+def MOV_B1_i : MOVi<I1RT, "pred">;
+def MOV_B16_i : MOVi<I16RT, "b16">;
+def MOV_B32_i : MOVi<I32RT, "b32">;
+def MOV_B64_i : MOVi<I64RT, "b64">;
+def MOV_F16_i : MOVi<F16RT, "b16">;
+def MOV_BF16_i : MOVi<BF16RT, "b16">;
+def MOV_F32_i : MOVi<F32RT, "b32">;
+def MOV_F64_i : MOVi<F64RT, "b64">;
def to_tglobaladdr : SDNodeXForm<globaladdr, [{
@@ -1638,11 +1506,11 @@ def to_tframeindex : SDNodeXForm<frameindex, [{
return CurDAG->getTargetFrameIndex(N->getIndex(), N->getValueType(0));
}]>;
-def : Pat<(i32 globaladdr:$dst), (IMOV32i (to_tglobaladdr $dst))>;
-def : Pat<(i64 globaladdr:$dst), (IMOV64i (to_tglobaladdr $dst))>;
+def : Pat<(i32 globaladdr:$dst), (MOV_B32_i (to_tglobaladdr $dst))>;
+def : Pat<(i64 globaladdr:$dst), (MOV_B64_i (to_tglobaladdr $dst))>;
-def : Pat<(i32 externalsym:$dst), (IMOV32i (to_texternsym $dst))>;
-def : Pat<(i64 externalsym:$dst), (IMOV64i (to_texternsym $dst))>;
+def : Pat<(i32 externalsym:$dst), (MOV_B32_i (to_texternsym $dst))>;
+def : Pat<(i64 externalsym:$dst), (MOV_B64_i (to_texternsym $dst))>;
//---- Copy Frame Index ----
def LEA_ADDRi : NVPTXInst<(outs B32:$dst), (ins ADDR:$addr),
@@ -1831,7 +1699,6 @@ class LD<NVPTXRegClass regclass>
"\t$dst, [$addr];", []>;
let mayLoad=1, hasSideEffects=0 in {
- def LD_i8 : LD<B16>;
def LD_i16 : LD<B16>;
def LD_i32 : LD<B32>;
def LD_i64 : LD<B64>;
@@ -1847,7 +1714,6 @@ class ST<DAGOperand O>
" \t[$addr], $src;", []>;
let mayStore=1, hasSideEffects=0 in {
- def ST_i8 : ST<RI16>;
def ST_i16 : ST<RI16>;
def ST_i32 : ST<RI32>;
def ST_i64 : ST<RI64>;
@@ -1880,7 +1746,6 @@ multiclass LD_VEC<NVPTXRegClass regclass, bit support_v8 = false> {
"[$addr];", []>;
}
let mayLoad=1, hasSideEffects=0 in {
- defm LDV_i8 : LD_VEC<B16>;
defm LDV_i16 : LD_VEC<B16>;
defm LDV_i32 : LD_VEC<B32, support_v8 = true>;
defm LDV_i64 : LD_VEC<B64>;
@@ -1914,7 +1779,6 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
}
let mayStore=1, hasSideEffects=0 in {
- defm STV_i8 : ST_VEC<RI16>;
defm STV_i16 : ST_VEC<RI16>;
defm STV_i32 : ST_VEC<RI32, support_v8 = true>;
defm STV_i64 : ST_VEC<RI64>;
@@ -2084,14 +1948,14 @@ def : Pat<(i64 (anyext i32:$a)), (CVT_u64_u32 $a, CvtNONE)>;
// truncate i64
def : Pat<(i32 (trunc i64:$a)), (CVT_u32_u64 $a, CvtNONE)>;
def : Pat<(i16 (trunc i64:$a)), (CVT_u16_u64 $a, CvtNONE)>;
-def : Pat<(i1 (trunc i64:$a)), (SETP_i64ri (ANDb64ri $a, 1), 0, CmpNE)>;
+def : Pat<(i1 (trunc i64:$a)), (SETP_i64ri (AND_b64ri $a, 1), 0, CmpNE)>;
// truncate i32
def : Pat<(i16 (trunc i32:$a)), (CVT_u16_u32 $a, CvtNONE)>;
-def : Pat<(i1 (trunc i32:$a)), (SETP_i32ri (ANDb32ri $a, 1), 0, CmpNE)>;
+def : Pat<(i1 (trunc i32:$a)), (SETP_i32ri (AND_b32ri $a, 1), 0, CmpNE)>;
// truncate i16
-def : Pat<(i1 (trunc i16:$a)), (SETP_i16ri (ANDb16ri $a, 1), 0, CmpNE)>;
+def : Pat<(i1 (trunc i16:$a)), (SETP_i16ri (AND_b16ri $a, 1), 0, CmpNE)>;
// sext_inreg
def : Pat<(sext_inreg i16:$a, i8), (CVT_INREG_s16_s8 $a)>;
@@ -2335,32 +2199,20 @@ defm : CVT_ROUND<frint, CvtRNI, CvtRNI_FTZ>;
//-----------------------------------
let isTerminator=1 in {
- let isReturn=1, isBarrier=1 in
+ let isReturn=1, isBarrier=1 in
def Return : BasicNVPTXInst<(outs), (ins), "ret", [(retglue)]>;
- let isBranch=1 in
- def CBranch : NVPTXInst<(outs), (ins B1:$a, brtarget:$target),
+ let isBranch=1 in {
+ def CBranch : NVPTXInst<(outs), (ins B1:$a, brtarget:$target),
"@$a bra \t$target;",
[(brcond i1:$a, bb:$target)]>;
- let isBranch=1 in
- def CBranchOther : NVPTXInst<(outs), (ins B1:$a, brtarget:$target),
- "@!$a bra \t$target;", []>;
- let isBranch=1, isBarrier=1 in
+ let isBarrier=1 in
def GOTO : BasicNVPTXInst<(outs), (ins brtarget:$target),
- "bra.uni", [(br bb:$target)]>;
+ "bra.uni", [(br bb:$target)]>;
+ }
}
-def : Pat<(brcond i32:$a, bb:$target),
- (CBranch (SETP_i32ri $a, 0, CmpNE), bb:$target)>;
-
-// SelectionDAGBuilder::visitSWitchCase() will invert the condition of a
-// conditional branch if the target block is the next block so that the code
-// can fall through to the target block. The inversion is done by 'xor
-// condition, 1', which will be translated to (setne condition, -1). Since ptx
-// supports '@!pred bra target', we should use it.
-def : Pat<(brcond (i1 (setne i1:$a, -1)), bb:$target),
- (CBranchOther $a, bb:$target)>;
// trap instruction
def trapinst : BasicNVPTXInst<(outs), (ins), "trap", [(trap)]>, Requires<[noPTXASUnreachableBug]>;
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 0a00220d94289..d33719236b172 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -243,63 +243,82 @@ foreach sync = [false, true] in {
}
// vote.{all,any,uni,ballot}
-multiclass VOTE<NVPTXRegClass regclass, string mode, Intrinsic IntOp> {
- def : BasicNVPTXInst<(outs regclass:$dest), (ins B1:$pred),
- "vote." # mode,
- [(set regclass:$dest, (IntOp i1:$pred))]>,
- Requires<[hasPTX<60>, hasSM<30>]>;
-}
+let Predicates = [hasPTX<60>, hasSM<30>] in {
+ multiclass VOTE<string mode, RegTyInfo t, Intrinsic op> {
+ def : BasicNVPTXInst<(outs t.RC:$dest), (ins B1:$pred),
+ "vote." # mode # "." # t.PtxType,
+ [(set t.Ty:$dest, (op i1:$pred))]>;
+ }
-defm VOTE_ALL : VOTE<B1, "all.pred", int_nvvm_vote_all>;
-defm VOTE_ANY : VOTE<B1, "any.pred", int_nvvm_vote_any>;
-defm VOTE_UNI : VOTE<B1, "uni.pred", int_nvvm_vote_uni>;
-defm VOTE_BALLOT : VOTE<B32, "ballot.b32", int_nvvm_vote_ballot>;
+ defm VOTE_ALL : VOTE<"all", I1RT, int_nvvm_vote_all>;
+ defm VOTE_ANY : VOTE<"any", I1RT, int_nvvm_vote_any>;
+ defm VOTE_UNI : VOTE<"uni", I1RT, int_nvvm_vote_uni>;
+ defm VOTE_BALLOT : VOTE<"ballot", I32RT, int_nvvm_vote_ballot>;
+
+ // vote.sync.{all,any,uni,ballot}
+ multiclass VOTE_SYNC<string mode, RegTyInfo t, Intrinsic op> {
+ def i : BasicNVPTXInst<(outs t.RC:$dest), (ins B1:$pred, i32imm:$mask),
+ "vote.sync." # mode # "." # t.PtxType,
+ [(set t.Ty:$dest, (op imm:$mask, i1:$pred))]>;
+ def r : BasicNVPTXInst<(outs t.RC:$dest), (ins B1:$pred, B32:$mask),
+ "vote.sync." # mode # "." # t.PtxType,
+ [(set t.Ty:$dest, (op i32:$mask, i1:$pred))]>;
+ }
-// vote.sync.{all,any,uni,ballot}
-multiclass VOTE_SYNC<NVPTXRegClass regclass, string mode, Intrinsic IntOp> {
- def i : BasicNVPTXInst<(outs regclass:$dest), (ins B1:$pred, i32imm:$mask),
- "vote.sync." # mode,
- [(set regclass:$dest, (IntOp imm:$mask, i1:$pred))]>,
- Requires<[hasPTX<60>, hasSM<30>]>;
- def r : BasicNVPTXInst<(outs regclass:$dest), (ins B1:$pred, B32:$mask),
- "vote.sync." # mode,
- [(set regclass:$dest, (IntOp i32:$mask, i1:$pred))]>,
- Requires<[hasPTX<60>, hasSM<30>]>;
+ defm VOTE_SYNC_ALL : VOTE_SYNC<"all", I1RT, int_nvvm_vote_all_sync>;
+ defm VOTE_SYNC_ANY : VOTE_SYNC<"any", I1RT, int_nvvm_vote_any_sync>;
+ defm VOTE_SYNC_UNI : VOTE_SYNC<"uni", I1RT, int_nvvm_vote_uni_sync>;
+ defm VOTE_SYNC_BALLOT : VOTE_SYNC<"ballot", I32RT, int_nvvm_vote_ballot_sync>;
}
-
-defm VOTE_SYNC_ALL : VOTE_SYNC<B1, "all.pred", int_nvvm_vote_all_sync>;
-defm VOTE_SYNC_ANY : VOTE_SYNC<B1, "any.pred", int_nvvm_vote_any_sync>;
-defm VOTE_SYNC_UNI : VOTE_SYNC<B1, "uni.pred", int_nvvm_vote_uni_sync>;
-defm VOTE_SYNC_BALLOT : VOTE_SYNC<B32, "ballot.b32", int_nvvm_vote_ballot_sync>;
-
// elect.sync
+let Predicates = [hasPTX<80>, hasSM<90>] in {
def INT_ELECT_SYNC_I : BasicNVPTXInst<(outs B32:$dest, B1:$pred), (ins i32imm:$mask),
"elect.sync",
- [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync imm:$mask))]>,
- Requires<[hasPTX<80>, hasSM<90>]>;
+ [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync imm:$mask))]>;
def INT_ELECT_SYNC_R : BasicNVPTXInst<(outs B32:$dest, B1:$pred), (ins B32:$mask),
"elect.sync",
- [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync i32:$mask))]>,
- Requires<[hasPTX<80>, hasSM<90>]>;
+ [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync i32:$mask))]>;
+}
+
+let Predicates = [hasPTX<60>, hasSM<70>] in {
+ multiclass MATCH_ANY_SYNC<Intrinsic op, RegTyInfo t> {
+ def ii : BasicNVPTXInst<(outs B32:$dest), (ins t.Imm:$value, i32imm:$mask),
+ "match.any.sync." # t.PtxType,
+ [(set i32:$dest, (op imm:$mask, imm:$value))]>;
+ def ir : BasicNVPTXInst<(outs B32:$dest), (ins t.Imm:$value, B32:$mask),
+ "match.any.sync." # t.PtxType,
+ [(set i32:$dest, (op i32:$mask, imm:$value))]>;
+ def ri : BasicNVPTXInst<(outs B32:$dest), (ins t.RC:$value, i32imm:$mask),
+ "match.any.sync." # t.PtxType,
+ [(set i32:$dest, (op imm:$mask, t.Ty:$value))]>;
+ def rr : BasicNVPTXInst<(outs B32:$dest), (ins t.RC:$value, B32:$mask),
+ "match.any.sync." # t.PtxType,
+ [(set i32:$dest, (op i32:$mask, t.Ty:$value))]>;
+ }
-multiclass MATCH_ANY_SYNC<NVPTXRegClass regclass, string ptxtype, Intrinsic IntOp,
- Operand ImmOp> {
- def ii : BasicNVPTXInst<(outs B32:$dest), (ins ImmOp:$value, i32imm:$mask),
- "match.any.sync." # ptxtype,
- [(set i32:$dest, (IntOp imm:$mask, imm:$value))]>,
- Requires<[hasPTX<60>, hasSM<70>]>;
- def ir : BasicNVPTXInst<(outs B32:$dest), (ins ImmOp:$value, B32:$mask),
- "match.any.sync." # ptxtype,
- [(set i32:$dest, (IntOp i32:$mask, imm:$value))]>,
- Requires<[hasPTX<60>, hasSM<70>]>;
- def ri : BasicNVPTXInst<(outs B32:$dest), (ins regclass:$value, i32imm:$mask),
- "match.any.sync." # ptxtype,
- [(set i32:$dest, (IntOp imm:$mask, regclass:$value))]>,
- Requires<[hasPTX<60>, hasSM<70>]>;
- def rr : BasicNVPTXInst<(outs B32:$dest), (ins regclass:$value, B32:$mask),
- "match.any.sync." # ptxtype,
- [(set i32:$dest, (IntOp i32:$mask, regclass:$value))]>,
- Requires<[hasPTX<60>, hasSM<70>]>;
+ defm MATCH_ANY_SYNC_32 : MATCH_ANY_SYNC<int_nvvm_match_any_sync_i32, I32RT>;
+ defm MATCH_ANY_SYNC_64 : MATCH_ANY_SYNC<int_nvvm_match_any_sync_i64, I64RT>;
+
+ multiclass MATCH_ALLP_SYNC<RegTyInfo t, Intrinsic op> {
+ def ii : BasicNVPTXInst<(outs B32:$dest, B1:$pred),
+ (ins t.Imm:$value, i32imm:$mask),
+ "match.all.sync." # t.PtxType,
+ [(set i32:$dest, i1:$pred, (op imm:$mask, imm:$value))]>;
+ def ir : BasicNVPTXInst<(outs B32:$dest, B1:$pred),
+ (ins t.Imm:$value, B32:$mask),
+ "match.all.sync." # t.PtxType,
+ [(set i32:$dest, i1:$pred, (op i32:$mask, imm:$value))]>;
+ def ri : BasicNVPTXInst<(outs B32:$dest, B1:$pred),
+ (ins t.RC:$value, i32imm:$mask),
+ "match.all.sync." # t.PtxType,
+ [(set i32:$dest, i1:$pred, (op imm:$mask, t.Ty:$value))]>;
+ def rr : BasicNVPTXInst<(outs B32:$dest, B1:$pred),
+ (ins t.RC:$value, B32:$mask),
+ "match.all.sync." # t.PtxType,
+ [(set i32:$dest, i1:$pred, (op i32:$mask, t.Ty:$value))]>;
+ }
+ defm MATCH_ALLP_SYNC_32 : MATCH_ALLP_SYNC<I32RT, int_nvvm_match_all_sync_i32p>;
+ defm MATCH_ALLP_SYNC_64 : MATCH_ALLP_SYNC<I64RT, int_nvvm_match_all_sync_i64p>;
}
// activemask.b32
@@ -308,39 +327,6 @@ def ACTIVEMASK : BasicNVPTXInst<(outs B32:$dest), (ins),
[(set i32:$dest, (int_nvvm_activemask))]>,
Requires<[hasPTX<62>, hasSM<30>]>;
-defm MATCH_ANY_SYNC_32 : MATCH_ANY_SYNC<B32, "b32", int_nvvm_match_any_sync_i32,
- i32imm>;
-defm MATCH_ANY_SYNC_64 : MATCH_ANY_SYNC<B64, "b64", int_nvvm_match_any_sync_i64,
- i64imm>;
-
-multiclass MATCH_ALLP_SYNC<NVPTXRegClass regclass, string ptxtype, Intrinsic IntOp,
- Operand ImmOp> {
- def ii : BasicNVPTXInst<(outs B32:$dest, B1:$pred),
- (ins ImmOp:$value, i32imm:$mask),
- "match.all.sync." # ptxtype,
- [(set i32:$dest, i1:$pred, (IntOp imm:$mask, imm:$value))]>,
- Requires<[hasPTX<60>, hasSM<70>]>;
- def ir : BasicNVPTXInst<(outs B32:$dest, B1:$pred),
- (ins ImmOp:$value, B32:$mask),
- "match.all.sync." # ptxtype,
- [(set i32:$dest, i1:$pred, (IntOp i32:$mask, imm:$value))]>,
- Requires<[hasPTX<60>, hasSM<70>]>;
- def ri : BasicNVPTXInst<(outs B32:$dest, B1:$pred),
- (ins regclass:$value, i32imm:$mask),
- "match.all.sync." # ptxtype,
- [(set i32:$dest, i1:$pred, (IntOp imm:$mask, regclass:$value))]>,
- Requires<[hasPTX<60>, hasSM<70>]>;
- def rr : BasicNVPTXInst<(outs B32:$dest, B1:$pred),
- (ins regclass:$value, B32:$mask),
- "match.all.sync." # ptxtype,
- [(set i32:$dest, i1:$pred, (IntOp i32:$mask, regclass:$value))]>,
- Requires<[hasPTX<60>, hasSM<70>]>;
-}
-defm MATCH_ALLP_SYNC_32 : MATCH_ALLP_SYNC<B32, "b32", int_nvvm_match_all_sync_i32p,
- i32imm>;
-defm MATCH_ALLP_SYNC_64 : MATCH_ALLP_SYNC<B64, "b64", int_nvvm_match_all_sync_i64p,
- i64imm>;
-
multiclass REDUX_SYNC<string BinOp, string PTXType, Intrinsic Intrin> {
def : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src, B32:$mask),
"redux.sync." # BinOp # "." # PTXType,
@@ -381,24 +367,20 @@ defm REDUX_SYNC_FMAX_ABS_NAN: REDUX_SYNC_F<"max", ".abs", ".NaN">;
//-----------------------------------
// Explicit Memory Fence Functions
//-----------------------------------
-class MEMBAR<string StrOp, Intrinsic IntOP> :
- BasicNVPTXInst<(outs), (ins),
- StrOp, [(IntOP)]>;
+class NullaryInst<string StrOp, Intrinsic IntOP> :
+ BasicNVPTXInst<(outs), (ins), StrOp, [(IntOP)]>;
-def INT_MEMBAR_CTA : MEMBAR<"membar.cta", int_nvvm_membar_cta>;
-def INT_MEMBAR_GL : MEMBAR<"membar.gl", int_nvvm_membar_gl>;
-def INT_MEMBAR_SYS : MEMBAR<"membar.sys", int_nvvm_membar_sys>;
+def INT_MEMBAR_CTA : NullaryInst<"membar.cta", int_nvvm_membar_cta>;
+def INT_MEMBAR_GL : NullaryInst<"membar.gl", int_nvvm_membar_gl>;
+def INT_MEMBAR_SYS : NullaryInst<"membar.sys", int_nvvm_membar_sys>;
def INT_FENCE_SC_CLUSTER:
- MEMBAR<"fence.sc.cluster", int_nvvm_fence_sc_cluster>,
+ NullaryInst<"fence.sc.cluster", int_nvvm_fence_sc_cluster>,
Requires<[hasPTX<78>, hasSM<90>]>;
// Proxy fence (uni-directional)
-// fence.proxy.tensormap.release variants
-
class FENCE_PROXY_TENSORMAP_GENERIC_RELEASE<string Scope, Intrinsic Intr> :
- BasicNVPTXInst<(outs), (ins),
- "fence.proxy.tensormap::generic.release." # Scope, [(Intr)]>,
+ NullaryInst<"fence.proxy.tensormap::generic.release." # Scope, Intr>,
Requires<[hasPTX<83>, hasSM<90>]>;
def INT_FENCE_PROXY_TENSORMAP_GENERIC_RELEASE_CTA:
@@ -488,35 +470,31 @@ defm CP_ASYNC_CG_SHARED_GLOBAL_16 :
CP_ASYNC_SHARED_GLOBAL_I<"cg", "16", int_nvvm_cp_async_cg_shared_global_16,
int_nvvm_cp_async_cg_shared_global_16_s>;
-def CP_ASYNC_COMMIT_GROUP :
- BasicNVPTXInst<(outs), (ins), "cp.async.commit_group", [(int_nvvm_cp_async_commit_group)]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
+let Predicates = [hasPTX<70>, hasSM<80>] in {
+ def CP_ASYNC_COMMIT_GROUP :
+ NullaryInst<"cp.async.commit_group", int_nvvm_cp_async_commit_group>;
-def CP_ASYNC_WAIT_GROUP :
- BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.wait_group",
- [(int_nvvm_cp_async_wait_group timm:$n)]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
+ def CP_ASYNC_WAIT_GROUP :
+ BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.wait_group",
+ [(int_nvvm_cp_async_wait_group timm:$n)]>;
-def CP_ASYNC_WAIT_ALL :
- BasicNVPTXInst<(outs), (ins), "cp.async.wait_all",
- [(int_nvvm_cp_async_wait_all)]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
+ def CP_ASYNC_WAIT_ALL :
+ NullaryInst<"cp.async.wait_all", int_nvvm_cp_async_wait_all>;
+}
-// cp.async.bulk variants of the commit/wait group
-def CP_ASYNC_BULK_COMMIT_GROUP :
- BasicNVPTXInst<(outs), (ins), "cp.async.bulk.commit_group",
- [(int_nvvm_cp_async_bulk_commit_group)]>,
- Requires<[hasPTX<80>, hasSM<90>]>;
+let Predicates = [hasPTX<80>, hasSM<90>] in {
+ // cp.async.bulk variants of the commit/wait group
+ def CP_ASYNC_BULK_COMMIT_GROUP :
+ NullaryInst<"cp.async.bulk.commit_group", int_nvvm_cp_async_bulk_commit_group>;
-def CP_ASYNC_BULK_WAIT_GROUP :
- BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group",
- [(int_nvvm_cp_async_bulk_wait_group timm:$n)]>,
- Requires<[hasPTX<80>, hasSM<90>]>;
+ def CP_ASYNC_BULK_WAIT_GROUP :
+ BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group",
+ [(int_nvvm_cp_async_bulk_wait_group timm:$n)]>;
-def CP_ASYNC_BULK_WAIT_GROUP_READ :
- BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group.read",
- [(int_nvvm_cp_async_bulk_wait_group_read timm:$n)]>,
- Requires<[hasPTX<80>, hasSM<90>]>;
+ def CP_ASYNC_BULK_WAIT_GROUP_READ :
+ BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group.read",
+ [(int_nvvm_cp_async_bulk_wait_group_read timm:$n)]>;
+}
//------------------------------
// TMA Async Bulk Copy Functions
@@ -974,33 +952,30 @@ defm TMA_TENSOR_PF_TILE_GATHER4_2D : TMA_TENSOR_PREFETCH_INTR<5, "tile_gather4",
//Prefetch and Prefetchu
-class PREFETCH_INTRS<string InstName> :
- BasicNVPTXInst<(outs), (ins ADDR:$addr),
- InstName,
- [(!cast<Intrinsic>(!strconcat("int_nvvm_",
- !subst(".", "_", InstName))) addr:$addr)]>,
- Requires<[hasPTX<80>, hasSM<90>]>;
-
+let Predicates = [hasPTX<80>, hasSM<90>] in {
+ class PREFETCH_INTRS<string InstName> :
+ BasicNVPTXInst<(outs), (ins ADDR:$addr),
+ InstName,
+ [(!cast<Intrinsic>(!strconcat("int_nvvm_",
+ !subst(".", "_", InstName))) addr:$addr)]>;
-def PREFETCH_L1 : PREFETCH_INTRS<"prefetch.L1">;
-def PREFETCH_L2 : PREFETCH_INTRS<"prefetch.L2">;
-def PREFETCH_GLOBAL_L1 : PREFETCH_INTRS<"prefetch.global.L1">;
-def PREFETCH_LOCAL_L1 : PREFETCH_INTRS<"prefetch.local.L1">;
-def PREFETCH_GLOBAL_L2 : PREFETCH_INTRS<"prefetch.global.L2">;
-def PREFETCH_LOCAL_L2 : PREFETCH_INTRS<"prefetch.local.L2">;
+ def PREFETCH_L1 : PREFETCH_INTRS<"prefetch.L1">;
+ def PREFETCH_L2 : PREFETCH_INTRS<"prefetch.L2">;
+ def PREFETCH_GLOBAL_L1 : PREFETCH_INTRS<"prefetch.global.L1">;
+ def PREFETCH_LOCAL_L1 : PREFETCH_INTRS<"prefetch.local.L1">;
+ def PREFETCH_GLOBAL_L2 : PREFETCH_INTRS<"prefetch.global.L2">;
+ def PREFETCH_LOCAL_L2 : PREFETCH_INTRS<"prefetch.local.L2">;
-def PREFETCH_GLOBAL_L2_EVICT_NORMAL : BasicNVPTXInst<(outs), (ins ADDR:$addr),
- "prefetch.global.L2::evict_normal",
- [(int_nvvm_prefetch_global_L2_evict_normal addr:$addr)]>,
- Requires<[hasPTX<80>, hasSM<90>]>;
+ def PREFETCH_GLOBAL_L2_EVICT_NORMAL : BasicNVPTXInst<(outs), (ins ADDR:$addr),
+ "prefetch.global.L2::evict_normal",
+ [(int_nvvm_prefetch_global_L2_evict_normal addr:$addr)]>;
-def PREFETCH_GLOBAL_L2_EVICT_LAST : BasicNVPTXInst<(outs), (ins ADDR:$addr),
- "prefetch.global.L2::evict_last",
- [(int_nvvm_prefetch_global_L2_evict_last addr:$addr)]>,
- Requires<[hasPTX<80>, hasSM<90>]>;
+ def PREFETCH_GLOBAL_L2_EVICT_LAST : BasicNVPTXInst<(outs), (ins ADDR:$addr),
+ "prefetch.global.L2::evict_last",
+ [(int_nvvm_prefetch_global_L2_evict_last addr:$addr)]>;
-
-def PREFETCHU_L1 : PREFETCH_INTRS<"prefetchu.L1">;
+ def PREFETCHU_L1 : PREFETCH_INTRS<"prefetchu.L1">;
+}
//Applypriority intrinsics
class APPLYPRIORITY_L2_INTRS<string addrspace> :
@@ -1031,99 +1006,82 @@ def DISCARD_GLOBAL_L2 : DISCARD_L2_INTRS<"global">;
// MBarrier Functions
//-----------------------------------
-multiclass MBARRIER_INIT<string AddrSpace, Intrinsic Intrin> {
- def "" : BasicNVPTXInst<(outs), (ins ADDR:$addr, B32:$count),
- "mbarrier.init" # AddrSpace # ".b64",
- [(Intrin addr:$addr, i32:$count)]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
-}
-
-defm MBARRIER_INIT : MBARRIER_INIT<"", int_nvvm_mbarrier_init>;
-defm MBARRIER_INIT_SHARED : MBARRIER_INIT<".shared",
- int_nvvm_mbarrier_init_shared>;
-
-multiclass MBARRIER_INVAL<string AddrSpace, Intrinsic Intrin> {
- def "" : BasicNVPTXInst<(outs), (ins ADDR:$addr),
- "mbarrier.inval" # AddrSpace # ".b64",
- [(Intrin addr:$addr)]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
-}
-
-defm MBARRIER_INVAL : MBARRIER_INVAL<"", int_nvvm_mbarrier_inval>;
-defm MBARRIER_INVAL_SHARED : MBARRIER_INVAL<".shared",
- int_nvvm_mbarrier_inval_shared>;
-
-multiclass MBARRIER_ARRIVE<string AddrSpace, Intrinsic Intrin> {
- def "" : BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr),
- "mbarrier.arrive" # AddrSpace # ".b64",
- [(set i64:$state, (Intrin addr:$addr))]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
-}
-
-defm MBARRIER_ARRIVE : MBARRIER_ARRIVE<"", int_nvvm_mbarrier_arrive>;
-defm MBARRIER_ARRIVE_SHARED :
- MBARRIER_ARRIVE<".shared", int_nvvm_mbarrier_arrive_shared>;
-
-multiclass MBARRIER_ARRIVE_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> {
- def "" : BasicNVPTXInst<(outs B64:$state),
- (ins ADDR:$addr, B32:$count),
- "mbarrier.arrive.noComplete" # AddrSpace # ".b64",
- [(set i64:$state, (Intrin addr:$addr, i32:$count))]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
-}
-
-defm MBARRIER_ARRIVE_NOCOMPLETE :
- MBARRIER_ARRIVE_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_noComplete>;
-defm MBARRIER_ARRIVE_NOCOMPLETE_SHARED :
- MBARRIER_ARRIVE_NOCOMPLETE<".shared", int_nvvm_mbarrier_arrive_noComplete_shared>;
-
-multiclass MBARRIER_ARRIVE_DROP<string AddrSpace, Intrinsic Intrin> {
- def "" : BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr),
- "mbarrier.arrive_drop" # AddrSpace # ".b64",
- [(set i64:$state, (Intrin addr:$addr))]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
-}
-
-defm MBARRIER_ARRIVE_DROP :
- MBARRIER_ARRIVE_DROP<"", int_nvvm_mbarrier_arrive_drop>;
-defm MBARRIER_ARRIVE_DROP_SHARED :
- MBARRIER_ARRIVE_DROP<".shared", int_nvvm_mbarrier_arrive_drop_shared>;
-
-multiclass MBARRIER_ARRIVE_DROP_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> {
- def "" : BasicNVPTXInst<(outs B64:$state),
- (ins ADDR:$addr, B32:$count),
- "mbarrier.arrive_drop.noComplete" # AddrSpace # ".b64",
- [(set i64:$state, (Intrin addr:$addr, i32:$count))]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
-}
-
-defm MBARRIER_ARRIVE_DROP_NOCOMPLETE :
- MBARRIER_ARRIVE_DROP_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_drop_noComplete>;
-defm MBARRIER_ARRIVE_DROP_NOCOMPLETE_SHARED :
- MBARRIER_ARRIVE_DROP_NOCOMPLETE<".shared",
- int_nvvm_mbarrier_arrive_drop_noComplete_shared>;
-
-multiclass MBARRIER_TEST_WAIT<string AddrSpace, Intrinsic Intrin> {
- def "" : BasicNVPTXInst<(outs B1:$res), (ins ADDR:$addr, B64:$state),
- "mbarrier.test_wait" # AddrSpace # ".b64",
- [(set i1:$res, (Intrin addr:$addr, i64:$state))]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
+let Predicates = [hasPTX<70>, hasSM<80>] in {
+ class MBARRIER_INIT<string AddrSpace, Intrinsic Intrin> :
+ BasicNVPTXInst<(outs), (ins ADDR:$addr, B32:$count),
+ "mbarrier.init" # AddrSpace # ".b64",
+ [(Intrin addr:$addr, i32:$count)]>;
+
+ def MBARRIER_INIT : MBARRIER_INIT<"", int_nvvm_mbarrier_init>;
+ def MBARRIER_INIT_SHARED : MBARRIER_INIT<".shared",
+ int_nvvm_mbarrier_init_shared>;
+
+ class MBARRIER_INVAL<string AddrSpace, Intrinsic Intrin> :
+ BasicNVPTXInst<(outs), (ins ADDR:$addr),
+ "mbarrier.inval" # AddrSpace # ".b64",
+ [(Intrin addr:$addr)]>;
+
+ def MBARRIER_INVAL : MBARRIER_INVAL<"", int_nvvm_mbarrier_inval>;
+ def MBARRIER_INVAL_SHARED : MBARRIER_INVAL<".shared",
+ int_nvvm_mbarrier_inval_shared>;
+
+ class MBARRIER_ARRIVE<string AddrSpace, Intrinsic Intrin> :
+ BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr),
+ "mbarrier.arrive" # AddrSpace # ".b64",
+ [(set i64:$state, (Intrin addr:$addr))]>;
+
+ def MBARRIER_ARRIVE : MBARRIER_ARRIVE<"", int_nvvm_mbarrier_arrive>;
+ def MBARRIER_ARRIVE_SHARED :
+ MBARRIER_ARRIVE<".shared", int_nvvm_mbarrier_arrive_shared>;
+
+ class MBARRIER_ARRIVE_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> :
+ BasicNVPTXInst<(outs B64:$state),
+ (ins ADDR:$addr, B32:$count),
+ "mbarrier.arrive.noComplete" # AddrSpace # ".b64",
+ [(set i64:$state, (Intrin addr:$addr, i32:$count))]>;
+
+ def MBARRIER_ARRIVE_NOCOMPLETE :
+ MBARRIER_ARRIVE_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_noComplete>;
+ def MBARRIER_ARRIVE_NOCOMPLETE_SHARED :
+ MBARRIER_ARRIVE_NOCOMPLETE<".shared", int_nvvm_mbarrier_arrive_noComplete_shared>;
+
+ class MBARRIER_ARRIVE_DROP<string AddrSpace, Intrinsic Intrin> :
+ BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr),
+ "mbarrier.arrive_drop" # AddrSpace # ".b64",
+ [(set i64:$state, (Intrin addr:$addr))]>;
+
+ def MBARRIER_ARRIVE_DROP :
+ MBARRIER_ARRIVE_DROP<"", int_nvvm_mbarrier_arrive_drop>;
+ def MBARRIER_ARRIVE_DROP_SHARED :
+ MBARRIER_ARRIVE_DROP<".shared", int_nvvm_mbarrier_arrive_drop_shared>;
+
+ class MBARRIER_ARRIVE_DROP_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> :
+ BasicNVPTXInst<(outs B64:$state),
+ (ins ADDR:$addr, B32:$count),
+ "mbarrier.arrive_drop.noComplete" # AddrSpace # ".b64",
+ [(set i64:$state, (Intrin addr:$addr, i32:$count))]>;
+
+ def MBARRIER_ARRIVE_DROP_NOCOMPLETE :
+ MBARRIER_ARRIVE_DROP_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_drop_noComplete>;
+ def MBARRIER_ARRIVE_DROP_NOCOMPLETE_SHARED :
+ MBARRIER_ARRIVE_DROP_NOCOMPLETE<".shared",
+ int_nvvm_mbarrier_arrive_drop_noComplete_shared>;
+
+ class MBARRIER_TEST_WAIT<string AddrSpace, Intrinsic Intrin> :
+ BasicNVPTXInst<(outs B1:$res), (ins ADDR:$addr, B64:$state),
+ "mbarrier.test_wait" # AddrSpace # ".b64",
+ [(set i1:$res, (Intrin addr:$addr, i64:$state))]>;
+
+ def MBARRIER_TEST_WAIT :
+ MBARRIER_TEST_WAIT<"", int_nvvm_mbarrier_test_wait>;
+ def MBARRIER_TEST_WAIT_SHARED :
+ MBARRIER_TEST_WAIT<".shared", int_nvvm_mbarrier_test_wait_shared>;
+
+ def MBARRIER_PENDING_COUNT :
+ BasicNVPTXInst<(outs B32:$res), (ins B64:$state),
+ "mbarrier.pending_count.b64",
+ [(set i32:$res, (int_nvvm_mbarrier_pending_count i64:$state))]>;
}
-
-defm MBARRIER_TEST_WAIT :
- MBARRIER_TEST_WAIT<"", int_nvvm_mbarrier_test_wait>;
-defm MBARRIER_TEST_WAIT_SHARED :
- MBARRIER_TEST_WAIT<".shared", int_nvvm_mbarrier_test_wait_shared>;
-
-class MBARRIER_PENDING_COUNT<Intrinsic Intrin> :
- BasicNVPTXInst<(outs B32:$res), (ins B64:$state),
- "mbarrier.pending_count.b64",
- [(set i32:$res, (Intrin i64:$state))]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
-
-def MBARRIER_PENDING_COUNT :
- MBARRIER_PENDING_COUNT<int_nvvm_mbarrier_pending_count>;
-
//-----------------------------------
// Math Functions
//-----------------------------------
@@ -1449,15 +1407,11 @@ defm ABS_F64 : F_ABS<"f64", F64RT, support_ftz = false>;
def fcopysign_nvptx : SDNode<"NVPTXISD::FCOPYSIGN", SDTFPBinOp>;
-def COPYSIGN_F :
- BasicNVPTXInst<(outs B32:$dst), (ins B32:$src0, B32:$src1),
- "copysign.f32",
- [(set f32:$dst, (fcopysign_nvptx f32:$src1, f32:$src0))]>;
-
-def COPYSIGN_D :
- BasicNVPTXInst<(outs B64:$dst), (ins B64:$src0, B64:$src1),
- "copysign.f64",
- [(set f64:$dst, (fcopysign_nvptx f64:$src1, f64:$src0))]>;
+foreach t = [F32RT, F64RT] in
+ def COPYSIGN_ # t :
+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src0, t.RC:$src1),
+ "copysign." # t.PtxType,
+ [(set t.Ty:$dst, (fcopysign_nvptx t.Ty:$src1, t.Ty:$src0))]>;
//
// Neg bf16, bf16x2
@@ -2255,38 +2209,35 @@ defm INT_PTX_SATOM_XOR : ATOM2_bitwise_impl<"xor">;
// Scalar
-class LDU_G<string TyStr, NVPTXRegClass regclass>
- : NVPTXInst<(outs regclass:$result), (ins ADDR:$src),
- "ldu.global." # TyStr # " \t$result, [$src];", []>;
+class LDU_G<NVPTXRegClass regclass>
+ : NVPTXInst<(outs regclass:$result), (ins i32imm:$fromWidth, ADDR:$src),
+ "ldu.global.b$fromWidth \t$result, [$src];", []>;
-def LDU_GLOBAL_i8 : LDU_G<"b8", B16>;
-def LDU_GLOBAL_i16 : LDU_G<"b16", B16>;
-def LDU_GLOBAL_i32 : LDU_G<"b32", B32>;
-def LDU_GLOBAL_i64 : LDU_G<"b64", B64>;
+def LDU_GLOBAL_i16 : LDU_G<B16>;
+def LDU_GLOBAL_i32 : LDU_G<B32>;
+def LDU_GLOBAL_i64 : LDU_G<B64>;
// vector
// Elementized vector ldu
-class VLDU_G_ELE_V2<string TyStr, NVPTXRegClass regclass>
+class VLDU_G_ELE_V2<NVPTXRegClass regclass>
: NVPTXInst<(outs regclass:$dst1, regclass:$dst2),
- (ins ADDR:$src),
- "ldu.global.v2." # TyStr # " \t{{$dst1, $dst2}}, [$src];", []>;
+ (ins i32imm:$fromWidth, ADDR:$src),
+ "ldu.global.v2.b$fromWidth \t{{$dst1, $dst2}}, [$src];", []>;
-class VLDU_G_ELE_V4<string TyStr, NVPTXRegClass regclass>
- : NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
- regclass:$dst4), (ins ADDR:$src),
- "ldu.global.v4." # TyStr # " \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", []>;
+class VLDU_G_ELE_V4<NVPTXRegClass regclass>
+ : NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4),
+ (ins i32imm:$fromWidth, ADDR:$src),
+ "ldu.global.v4.b$fromWidth \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", []>;
-def LDU_GLOBAL_v2i8 : VLDU_G_ELE_V2<"b8", B16>;
-def LDU_GLOBAL_v2i16 : VLDU_G_ELE_V2<"b16", B16>;
-def LDU_GLOBAL_v2i32 : VLDU_G_ELE_V2<"b32", B32>;
-def LDU_GLOBAL_v2i64 : VLDU_G_ELE_V2<"b64", B64>;
+def LDU_GLOBAL_v2i16 : VLDU_G_ELE_V2<B16>;
+def LDU_GLOBAL_v2i32 : VLDU_G_ELE_V2<B32>;
+def LDU_GLOBAL_v2i64 : VLDU_G_ELE_V2<B64>;
-def LDU_GLOBAL_v4i8 : VLDU_G_ELE_V4<"b8", B16>;
-def LDU_GLOBAL_v4i16 : VLDU_G_ELE_V4<"b16", B16>;
-def LDU_GLOBAL_v4i32 : VLDU_G_ELE_V4<"b32", B32>;
+def LDU_GLOBAL_v4i16 : VLDU_G_ELE_V4<B16>;
+def LDU_GLOBAL_v4i32 : VLDU_G_ELE_V4<B32>;
//-----------------------------------
@@ -2327,12 +2278,10 @@ class VLDG_G_ELE_V8<NVPTXRegClass regclass> :
"ld.global.nc.v8.${Sign:sign}$fromWidth \t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, [$src];", []>;
// FIXME: 8-bit LDG should be fixed once LDG/LDU nodes are made into proper loads.
-def LD_GLOBAL_NC_v2i8 : VLDG_G_ELE_V2<B16>;
def LD_GLOBAL_NC_v2i16 : VLDG_G_ELE_V2<B16>;
def LD_GLOBAL_NC_v2i32 : VLDG_G_ELE_V2<B32>;
def LD_GLOBAL_NC_v2i64 : VLDG_G_ELE_V2<B64>;
-def LD_GLOBAL_NC_v4i8 : VLDG_G_ELE_V4<B16>;
def LD_GLOBAL_NC_v4i16 : VLDG_G_ELE_V4<B16>;
def LD_GLOBAL_NC_v4i32 : VLDG_G_ELE_V4<B32>;
@@ -2342,19 +2291,19 @@ def LD_GLOBAL_NC_v8i32 : VLDG_G_ELE_V8<B32>;
multiclass NG_TO_G<string Str, bit Supports32 = 1, list<Predicate> Preds = []> {
if Supports32 then
def "" : BasicNVPTXInst<(outs B32:$result), (ins B32:$src),
- "cvta." # Str # ".u32", []>, Requires<Preds>;
+ "cvta." # Str # ".u32">, Requires<Preds>;
def _64 : BasicNVPTXInst<(outs B64:$result), (ins B64:$src),
- "cvta." # Str # ".u64", []>, Requires<Preds>;
+ "cvta." # Str # ".u64">, Requires<Preds>;
}
multiclass G_TO_NG<string Str, bit Supports32 = 1, list<Predicate> Preds = []> {
if Supports32 then
def "" : BasicNVPTXInst<(outs B32:$result), (ins B32:$src),
- "cvta.to." # Str # ".u32", []>, Requires<Preds>;
+ "cvta.to." # Str # ".u32">, Requires<Preds>;
def _64 : BasicNVPTXInst<(outs B64:$result), (ins B64:$src),
- "cvta.to." # Str # ".u64", []>, Requires<Preds>;
+ "cvta.to." # Str # ".u64">, Requires<Preds>;
}
foreach space = ["local", "shared", "global", "const", "param"] in {
@@ -4614,9 +4563,9 @@ def INT_PTX_SREG_LANEMASK_GT :
PTX_READ_SREG_R32<"lanemask_gt", int_nvvm_read_ptx_sreg_lanemask_gt>;
let hasSideEffects = 1 in {
-def SREG_CLOCK : PTX_READ_SREG_R32<"clock", int_nvvm_read_ptx_sreg_clock>;
-def SREG_CLOCK64 : PTX_READ_SREG_R64<"clock64", int_nvvm_read_ptx_sreg_clock64>;
-def SREG_GLOBALTIMER : PTX_READ_SREG_R64<"globaltimer", int_nvvm_read_ptx_sreg_globaltimer>;
+ def SREG_CLOCK : PTX_READ_SREG_R32<"clock", int_nvvm_read_ptx_sreg_clock>;
+ def SREG_CLOCK64 : PTX_READ_SREG_R64<"clock64", int_nvvm_read_ptx_sreg_clock64>;
+ def SREG_GLOBALTIMER : PTX_READ_SREG_R64<"globaltimer", int_nvvm_read_ptx_sreg_globaltimer>;
}
def: Pat <(i64 (readcyclecounter)), (SREG_CLOCK64)>;
@@ -5096,37 +5045,36 @@ foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in
def : MMA_PAT<mma>;
multiclass MAPA<string suffix, Intrinsic Intr> {
- def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, B32:$b),
- "mapa" # suffix # ".u32",
- [(set i32:$d, (Intr i32:$a, i32:$b))]>,
- Requires<[hasSM<90>, hasPTX<78>]>;
- def _32i: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, i32imm:$b),
- "mapa" # suffix # ".u32",
- [(set i32:$d, (Intr i32:$a, imm:$b))]>,
- Requires<[hasSM<90>, hasPTX<78>]>;
- def _64: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, B32:$b),
- "mapa" # suffix # ".u64",
- [(set i64:$d, (Intr i64:$a, i32:$b))]>,
- Requires<[hasSM<90>, hasPTX<78>]>;
- def _64i: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, i32imm:$b),
- "mapa" # suffix # ".u64",
- [(set i64:$d, (Intr i64:$a, imm:$b))]>,
- Requires<[hasSM<90>, hasPTX<78>]>;
+ let Predicates = [hasSM<90>, hasPTX<78>] in {
+ def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, B32:$b),
+ "mapa" # suffix # ".u32",
+ [(set i32:$d, (Intr i32:$a, i32:$b))]>;
+ def _32i: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, i32imm:$b),
+ "mapa" # suffix # ".u32",
+ [(set i32:$d, (Intr i32:$a, imm:$b))]>;
+ def _64: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, B32:$b),
+ "mapa" # suffix # ".u64",
+ [(set i64:$d, (Intr i64:$a, i32:$b))]>;
+ def _64i: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, i32imm:$b),
+ "mapa" # suffix # ".u64",
+ [(set i64:$d, (Intr i64:$a, imm:$b))]>;
+ }
}
+
defm mapa : MAPA<"", int_nvvm_mapa>;
defm mapa_shared_cluster : MAPA<".shared::cluster", int_nvvm_mapa_shared_cluster>;
multiclass GETCTARANK<string suffix, Intrinsic Intr> {
- def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a),
- "getctarank" # suffix # ".u32",
- [(set i32:$d, (Intr i32:$a))]>,
- Requires<[hasSM<90>, hasPTX<78>]>;
- def _64: BasicNVPTXInst<(outs B32:$d), (ins B64:$a),
- "getctarank" # suffix # ".u64",
- [(set i32:$d, (Intr i64:$a))]>,
- Requires<[hasSM<90>, hasPTX<78>]>;
+ let Predicates = [hasSM<90>, hasPTX<78>] in {
+ def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a),
+ "getctarank" # suffix # ".u32",
+ [(set i32:$d, (Intr i32:$a))]>;
+ def _64: BasicNVPTXInst<(outs B32:$d), (ins B64:$a),
+ "getctarank" # suffix # ".u64",
+ [(set i32:$d, (Intr i64:$a))]>;
+ }
}
defm getctarank : GETCTARANK<"", int_nvvm_getctarank>;
@@ -5165,29 +5113,25 @@ def INT_NVVM_WGMMA_WAIT_GROUP_SYNC_ALIGNED : BasicNVPTXInst<(outs), (ins i64imm:
[(int_nvvm_wgmma_wait_group_sync_aligned timm:$n)]>, Requires<[hasSM90a, hasPTX<80>]>;
} // isConvergent = true
-def GRIDDEPCONTROL_LAUNCH_DEPENDENTS :
- BasicNVPTXInst<(outs), (ins),
- "griddepcontrol.launch_dependents",
- [(int_nvvm_griddepcontrol_launch_dependents)]>,
- Requires<[hasSM<90>, hasPTX<78>]>;
-
-def GRIDDEPCONTROL_WAIT :
- BasicNVPTXInst<(outs), (ins),
- "griddepcontrol.wait",
- [(int_nvvm_griddepcontrol_wait)]>,
- Requires<[hasSM<90>, hasPTX<78>]>;
+let Predicates = [hasSM<90>, hasPTX<78>] in {
+ def GRIDDEPCONTROL_LAUNCH_DEPENDENTS :
+ BasicNVPTXInst<(outs), (ins), "griddepcontrol.launch_dependents",
+ [(int_nvvm_griddepcontrol_launch_dependents)]>;
+ def GRIDDEPCONTROL_WAIT :
+ BasicNVPTXInst<(outs), (ins), "griddepcontrol.wait",
+ [(int_nvvm_griddepcontrol_wait)]>;
+}
def INT_EXIT : BasicNVPTXInst<(outs), (ins), "exit", [(int_nvvm_exit)]>;
// Tcgen05 intrinsics
-let isConvergent = true in {
+let isConvergent = true, Predicates = [hasTcgen05Instructions] in {
multiclass TCGEN05_ALLOC_INTR<string AS, string num, Intrinsic Intr> {
def "" : BasicNVPTXInst<(outs),
(ins ADDR:$dst, B32:$ncols),
"tcgen05.alloc.cta_group::" # num # ".sync.aligned" # AS # ".b32",
- [(Intr addr:$dst, B32:$ncols)]>,
- Requires<[hasTcgen05Instructions]>;
+ [(Intr addr:$dst, B32:$ncols)]>;
}
defm TCGEN05_ALLOC_CG1 : TCGEN05_ALLOC_INTR<"", "1", int_nvvm_tcgen05_alloc_cg1>;
@@ -5200,8 +5144,7 @@ multiclass TCGEN05_DEALLOC_INTR<string num, Intrinsic Intr> {
def "" : BasicNVPTXInst<(outs),
(ins B32:$tmem_addr, B32:$ncols),
"tcgen05.dealloc.cta_group::" # num # ".sync.aligned.b32",
- [(Intr B32:$tmem_addr, B32:$ncols)]>,
- Requires<[hasTcgen05Instructions]>;
+ [(Intr B32:$tmem_addr, B32:$ncols)]>;
}
defm TCGEN05_DEALLOC_CG1: TCGEN05_DEALLOC_INTR<"1", int_nvvm_tcgen05_dealloc_cg1>;
defm TCGEN05_DEALLOC_CG2: TCGEN05_DEALLOC_INTR<"2", int_nvvm_tcgen05_dealloc_cg2>;
@@ -5209,19 +5152,13 @@ defm TCGEN05_DEALLOC_CG2: TCGEN05_DEALLOC_INTR<"2", int_nvvm_tcgen05_dealloc_cg2
multiclass TCGEN05_RELINQ_PERMIT_INTR<string num, Intrinsic Intr> {
def "" : BasicNVPTXInst<(outs), (ins),
"tcgen05.relinquish_alloc_permit.cta_group::" # num # ".sync.aligned",
- [(Intr)]>,
- Requires<[hasTcgen05Instructions]>;
+ [(Intr)]>;
}
defm TCGEN05_RELINQ_CG1: TCGEN05_RELINQ_PERMIT_INTR<"1", int_nvvm_tcgen05_relinq_alloc_permit_cg1>;
defm TCGEN05_RELINQ_CG2: TCGEN05_RELINQ_PERMIT_INTR<"2", int_nvvm_tcgen05_relinq_alloc_permit_cg2>;
-def tcgen05_wait_ld: BasicNVPTXInst<(outs), (ins), "tcgen05.wait::ld.sync.aligned",
- [(int_nvvm_tcgen05_wait_ld)]>,
- Requires<[hasTcgen05Instructions]>;
-
-def tcgen05_wait_st: BasicNVPTXInst<(outs), (ins), "tcgen05.wait::st.sync.aligned",
- [(int_nvvm_tcgen05_wait_st)]>,
- Requires<[hasTcgen05Instructions]>;
+def tcgen05_wait_ld: NullaryInst<"tcgen05.wait::ld.sync.aligned", int_nvvm_tcgen05_wait_ld>;
+def tcgen05_wait_st: NullaryInst<"tcgen05.wait::st.sync.aligned", int_nvvm_tcgen05_wait_st>;
multiclass TCGEN05_COMMIT_INTR<string AS, string num> {
defvar prefix = "tcgen05.commit.cta_group::" # num #".mbarrier::arrive::one.shared::cluster";
@@ -5232,12 +5169,10 @@ multiclass TCGEN05_COMMIT_INTR<string AS, string num> {
def "" : BasicNVPTXInst<(outs), (ins ADDR:$mbar),
prefix # ".b64",
- [(Intr addr:$mbar)]>,
- Requires<[hasTcgen05Instructions]>;
+ [(Intr addr:$mbar)]>;
def _MC : BasicNVPTXInst<(outs), (ins ADDR:$mbar, B16:$mc),
prefix # ".multicast::cluster.b64",
- [(IntrMC addr:$mbar, B16:$mc)]>,
- Requires<[hasTcgen05Instructions]>;
+ [(IntrMC addr:$mbar, B16:$mc)]>;
}
defm TCGEN05_COMMIT_CG1 : TCGEN05_COMMIT_INTR<"", "1">;
@@ -5249,8 +5184,7 @@ multiclass TCGEN05_SHIFT_INTR<string num, Intrinsic Intr> {
def "" : BasicNVPTXInst<(outs),
(ins ADDR:$tmem_addr),
"tcgen05.shift.cta_group::" # num # ".down",
- [(Intr addr:$tmem_addr)]>,
- Requires<[hasTcgen05Instructions]>;
+ [(Intr addr:$tmem_addr)]>;
}
defm TCGEN05_SHIFT_CG1: TCGEN05_SHIFT_INTR<"1", int_nvvm_tcgen05_shift_down_cg1>;
defm TCGEN05_SHIFT_CG2: TCGEN05_SHIFT_INTR<"2", int_nvvm_tcgen05_shift_down_cg2>;
@@ -5270,13 +5204,11 @@ multiclass TCGEN05_CP_INTR<string shape, string src_fmt, string mc = ""> {
def _cg1 : BasicNVPTXInst<(outs),
(ins ADDR:$tmem_addr, B64:$sdesc),
"tcgen05.cp.cta_group::1." # shape_mc_asm # fmt_asm,
- [(IntrCG1 addr:$tmem_addr, B64:$sdesc)]>,
- Requires<[hasTcgen05Instructions]>;
+ [(IntrCG1 addr:$tmem_addr, B64:$sdesc)]>;
def _cg2 : BasicNVPTXInst<(outs),
(ins ADDR:$tmem_addr, B64:$sdesc),
"tcgen05.cp.cta_group::2." # shape_mc_asm # fmt_asm,
- [(IntrCG2 addr:$tmem_addr, B64:$sdesc)]>,
- Requires<[hasTcgen05Instructions]>;
+ [(IntrCG2 addr:$tmem_addr, B64:$sdesc)]>;
}
foreach src_fmt = ["", "b6x16_p32", "b4x16_p64"] in {
@@ -5289,17 +5221,13 @@ foreach src_fmt = ["", "b6x16_p32", "b4x16_p64"] in {
}
} // isConvergent
-let hasSideEffects = 1 in {
+let hasSideEffects = 1, Predicates = [hasTcgen05Instructions] in {
-def tcgen05_fence_before_thread_sync: BasicNVPTXInst<(outs), (ins),
- "tcgen05.fence::before_thread_sync",
- [(int_nvvm_tcgen05_fence_before_thread_sync)]>,
- Requires<[hasTcgen05Instructions]>;
+ def tcgen05_fence_before_thread_sync: NullaryInst<
+ "tcgen05.fence::before_thread_sync", int_nvvm_tcgen05_fence_before_thread_sync>;
-def tcgen05_fence_after_thread_sync: BasicNVPTXInst<(outs), (ins),
- "tcgen05.fence::after_thread_sync",
- [(int_nvvm_tcgen05_fence_after_thread_sync)]>,
- Requires<[hasTcgen05Instructions]>;
+ def tcgen05_fence_after_thread_sync: NullaryInst<
+ "tcgen05.fence::after_thread_sync", int_nvvm_tcgen05_fence_after_thread_sync>;
} // hasSideEffects
@@ -5392,17 +5320,17 @@ foreach shape = ["16x64b", "16x128b", "16x256b", "32x32b", "16x32bx2"] in {
// Bulk store instructions
def st_bulk_imm : TImmLeaf<i64, [{ return Imm == 0; }]>;
-def INT_NVVM_ST_BULK_GENERIC :
- BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value),
- "st.bulk",
- [(int_nvvm_st_bulk addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>,
- Requires<[hasSM<100>, hasPTX<86>]>;
+let Predicates = [hasSM<100>, hasPTX<86>] in {
+ def INT_NVVM_ST_BULK_GENERIC :
+ BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value),
+ "st.bulk",
+ [(int_nvvm_st_bulk addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>;
-def INT_NVVM_ST_BULK_SHARED_CTA:
- BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value),
- "st.bulk.shared::cta",
- [(int_nvvm_st_bulk_shared_cta addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>,
- Requires<[hasSM<100>, hasPTX<86>]>;
+ def INT_NVVM_ST_BULK_SHARED_CTA:
+ BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value),
+ "st.bulk.shared::cta",
+ [(int_nvvm_st_bulk_shared_cta addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>;
+}
//
// clusterlaunchcontorl Instructions
diff --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
index d40886a56d6a4..2e81ab122d1df 100644
--- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
@@ -38,14 +38,6 @@ foreach i = 0...4 in {
def R#i : NVPTXReg<"%r"#i>; // 32-bit
def RL#i : NVPTXReg<"%rd"#i>; // 64-bit
def RQ#i : NVPTXReg<"%rq"#i>; // 128-bit
- def H#i : NVPTXReg<"%h"#i>; // 16-bit float
- def HH#i : NVPTXReg<"%hh"#i>; // 2x16-bit float
-
- // Arguments
- def ia#i : NVPTXReg<"%ia"#i>;
- def la#i : NVPTXReg<"%la"#i>;
- def fa#i : NVPTXReg<"%fa"#i>;
- def da#i : NVPTXReg<"%da"#i>;
}
foreach i = 0...31 in {
More information about the llvm-commits
mailing list