[llvm] e6ffccb - [NVPTX][NFC] Minor cleanup in NVPTXInstrInfo.td (#138006)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 30 15:03:55 PDT 2025
Author: Alex MacLean
Date: 2025-04-30T15:03:51-07:00
New Revision: e6ffccbaa7c908839af150be8f7e21e0f8844dcc
URL: https://github.com/llvm/llvm-project/commit/e6ffccbaa7c908839af150be8f7e21e0f8844dcc
DIFF: https://github.com/llvm/llvm-project/commit/e6ffccbaa7c908839af150be8f7e21e0f8844dcc.diff
LOG: [NVPTX][NFC] Minor cleanup in NVPTXInstrInfo.td (#138006)
Added:
Modified:
llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Removed:
################################################################################
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
index 0551954444e57..67dc7904a91ae 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
@@ -209,9 +209,7 @@ bool NVPTXInstrInfo::isSchedulingBoundary(const MachineInstr &MI,
switch (MI.getOpcode()) {
case NVPTX::CallUniPrintCallRetInst1:
case NVPTX::CallArgBeginInst:
- case NVPTX::CallArgI32imm:
case NVPTX::CallArgParam:
- case NVPTX::LastCallArgI32imm:
case NVPTX::LastCallArgParam:
case NVPTX::CallArgEndInst1:
return true;
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 043da14bcb236..11d77599d4ac3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1330,58 +1330,46 @@ def FDIV32ri_prec :
// FMA
//
-multiclass FMA<string OpcStr, RegisterClass RC, Operand ImmCls, Predicate Pred> {
+multiclass FMA<string OpcStr, RegTyInfo t, list<Predicate> Preds = []> {
defvar asmstr = OpcStr # " \t$dst, $a, $b, $c;";
- def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
+ def rrr : NVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, t.RC:$b, t.RC:$c),
asmstr,
- [(set RC:$dst, (fma RC:$a, RC:$b, RC:$c))]>,
- Requires<[Pred]>;
- def rri : NVPTXInst<(outs RC:$dst),
- (ins RC:$a, RC:$b, ImmCls:$c),
- asmstr,
- [(set RC:$dst, (fma RC:$a, RC:$b, fpimm:$c))]>,
- Requires<[Pred]>;
- def rir : NVPTXInst<(outs RC:$dst),
- (ins RC:$a, ImmCls:$b, RC:$c),
- asmstr,
- [(set RC:$dst, (fma RC:$a, fpimm:$b, RC:$c))]>,
- Requires<[Pred]>;
- def rii : NVPTXInst<(outs RC:$dst),
- (ins RC:$a, ImmCls:$b, ImmCls:$c),
- asmstr,
- [(set RC:$dst, (fma RC:$a, fpimm:$b, fpimm:$c))]>,
- Requires<[Pred]>;
- def iir : NVPTXInst<(outs RC:$dst),
- (ins ImmCls:$a, ImmCls:$b, RC:$c),
- asmstr,
- [(set RC:$dst, (fma fpimm:$a, fpimm:$b, RC:$c))]>,
- Requires<[Pred]>;
-
-}
-
-multiclass FMA_F16<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> {
- def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
- !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
- [(set T:$dst, (fma T:$a, T:$b, T:$c))]>,
- Requires<[useFP16Math, Pred]>;
-}
-
-multiclass FMA_BF16<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> {
- def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
- !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
- [(set T:$dst, (fma T:$a, T:$b, T:$c))]>,
- Requires<[hasBF16Math, Pred]>;
+ [(set t.Ty:$dst, (fma t.Ty:$a, t.Ty:$b, t.Ty:$c))]>,
+ Requires<Preds>;
+
+ if t.SupportsImm then {
+ def rri : NVPTXInst<(outs t.RC:$dst),
+ (ins t.RC:$a, t.RC:$b, t.Imm:$c),
+ asmstr,
+ [(set t.Ty:$dst, (fma t.Ty:$a, t.Ty:$b, fpimm:$c))]>,
+ Requires<Preds>;
+ def rir : NVPTXInst<(outs t.RC:$dst),
+ (ins t.RC:$a, t.Imm:$b, t.RC:$c),
+ asmstr,
+ [(set t.Ty:$dst, (fma t.Ty:$a, fpimm:$b, t.Ty:$c))]>,
+ Requires<Preds>;
+ def rii : NVPTXInst<(outs t.RC:$dst),
+ (ins t.RC:$a, t.Imm:$b, t.Imm:$c),
+ asmstr,
+ [(set t.Ty:$dst, (fma t.Ty:$a, fpimm:$b, fpimm:$c))]>,
+ Requires<Preds>;
+ def iir : NVPTXInst<(outs t.RC:$dst),
+ (ins t.Imm:$a, t.Imm:$b, t.RC:$c),
+ asmstr,
+ [(set t.Ty:$dst, (fma fpimm:$a, fpimm:$b, t.Ty:$c))]>,
+ Requires<Preds>;
+ }
}
-defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", f16, Int16Regs, doF32FTZ>;
-defm FMA16 : FMA_F16<"fma.rn.f16", f16, Int16Regs, True>;
-defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Int32Regs, doF32FTZ>;
-defm FMA16x2 : FMA_F16<"fma.rn.f16x2", v2f16, Int32Regs, True>;
-defm BFMA16 : FMA_BF16<"fma.rn.bf16", bf16, Int16Regs, True>;
-defm BFMA16x2 : FMA_BF16<"fma.rn.bf16x2", v2bf16, Int32Regs, True>;
-defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>;
-defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True>;
-defm FMA64 : FMA<"fma.rn.f64", Float64Regs, f64imm, True>;
+defm FMA16_ftz : FMA<"fma.rn.ftz.f16", F16RT, [useFP16Math, doF32FTZ]>;
+defm FMA16 : FMA<"fma.rn.f16", F16RT, [useFP16Math]>;
+defm FMA16x2_ftz : FMA<"fma.rn.ftz.f16x2", F16X2RT, [useFP16Math, doF32FTZ]>;
+defm FMA16x2 : FMA<"fma.rn.f16x2", F16X2RT, [useFP16Math]>;
+defm BFMA16 : FMA<"fma.rn.bf16", BF16RT, [hasBF16Math]>;
+defm BFMA16x2 : FMA<"fma.rn.bf16x2", BF16X2RT, [hasBF16Math]>;
+defm FMA32_ftz : FMA<"fma.rn.ftz.f32", F32RT, [doF32FTZ]>;
+defm FMA32 : FMA<"fma.rn.f32", F32RT>;
+defm FMA64 : FMA<"fma.rn.f64", F64RT>;
// sin/cos
@@ -1999,7 +1987,7 @@ multiclass FSET_FORMAT<PatFrag OpNode, PatLeaf Mode, PatLeaf ModeFTZ> {
Requires<[doF32FTZ]>;
def : Pat<(i1 (OpNode f32:$a, f32:$b)),
(SETP_f32rr $a, $b, Mode)>;
- def : Pat<(i1 (OpNode Float32Regs:$a, fpimm:$b)),
+ def : Pat<(i1 (OpNode f32:$a, fpimm:$b)),
(SETP_f32ri $a, fpimm:$b, ModeFTZ)>,
Requires<[doF32FTZ]>;
def : Pat<(i1 (OpNode f32:$a, fpimm:$b)),
@@ -2056,7 +2044,7 @@ def SDTStoreParamProfile : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>]>;
def SDTStoreParamV2Profile : SDTypeProfile<0, 4, [SDTCisInt<0>, SDTCisInt<1>]>;
def SDTStoreParamV4Profile : SDTypeProfile<0, 6, [SDTCisInt<0>, SDTCisInt<1>]>;
def SDTStoreParam32Profile : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>]>;
-def SDTCallArgProfile : SDTypeProfile<0, 2, [SDTCisInt<0>]>;
+def SDTCallArgProfile : SDTypeProfile<0, 2, [SDTCisVT<0, i32>, SDTCisVT<1, i32>]>;
def SDTCallArgMarkProfile : SDTypeProfile<0, 0, []>;
def SDTCallVoidProfile : SDTypeProfile<0, 1, []>;
def SDTCallValProfile : SDTypeProfile<1, 0, []>;
@@ -2352,42 +2340,10 @@ def CallArgEndInst1 : NVPTXInst<(outs), (ins), ");", [(CallArgEnd (i32 1))]>;
def CallArgEndInst0 : NVPTXInst<(outs), (ins), ")", [(CallArgEnd (i32 0))]>;
def RETURNInst : NVPTXInst<(outs), (ins), "ret;", [(RETURNNode)]>;
-class CallArgInst<NVPTXRegClass regclass> :
- NVPTXInst<(outs), (ins regclass:$a), "$a, ",
- [(CallArg (i32 0), regclass:$a)]>;
-
-class CallArgInstVT<NVPTXRegClass regclass, ValueType vt> :
- NVPTXInst<(outs), (ins regclass:$a), "$a, ",
- [(CallArg (i32 0), vt:$a)]>;
-
-class LastCallArgInst<NVPTXRegClass regclass> :
- NVPTXInst<(outs), (ins regclass:$a), "$a",
- [(LastCallArg (i32 0), regclass:$a)]>;
-class LastCallArgInstVT<NVPTXRegClass regclass, ValueType vt> :
- NVPTXInst<(outs), (ins regclass:$a), "$a",
- [(LastCallArg (i32 0), vt:$a)]>;
-
-def CallArgI64 : CallArgInst<Int64Regs>;
-def CallArgI32 : CallArgInstVT<Int32Regs, i32>;
-def CallArgI16 : CallArgInstVT<Int16Regs, i16>;
-def CallArgF64 : CallArgInst<Float64Regs>;
-def CallArgF32 : CallArgInst<Float32Regs>;
-
-def LastCallArgI64 : LastCallArgInst<Int64Regs>;
-def LastCallArgI32 : LastCallArgInstVT<Int32Regs, i32>;
-def LastCallArgI16 : LastCallArgInstVT<Int16Regs, i16>;
-def LastCallArgF64 : LastCallArgInst<Float64Regs>;
-def LastCallArgF32 : LastCallArgInst<Float32Regs>;
-
-def CallArgI32imm : NVPTXInst<(outs), (ins i32imm:$a), "$a, ",
- [(CallArg (i32 0), (i32 imm:$a))]>;
-def LastCallArgI32imm : NVPTXInst<(outs), (ins i32imm:$a), "$a",
- [(LastCallArg (i32 0), (i32 imm:$a))]>;
-
def CallArgParam : NVPTXInst<(outs), (ins i32imm:$a), "param$a, ",
- [(CallArg (i32 1), (i32 imm:$a))]>;
+ [(CallArg 1, imm:$a)]>;
def LastCallArgParam : NVPTXInst<(outs), (ins i32imm:$a), "param$a",
- [(LastCallArg (i32 1), (i32 imm:$a))]>;
+ [(LastCallArg 1, imm:$a)]>;
def CallVoidInst : NVPTXInst<(outs), (ins ADDR_base:$addr), "$addr, ",
[(CallVoid (Wrapper tglobaladdr:$addr))]>;
More information about the llvm-commits mailing list