[llvm] [NVPTX][NFC] Minor cleanup in NVPTXInstrInfo.td (PR #138006)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 30 11:08:39 PDT 2025


https://github.com/AlexMaclean created https://github.com/llvm/llvm-project/pull/138006

(No PR description provided.)

From 510e00d62f534c8e13e48530581281ca2f970e4b Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Wed, 30 Apr 2025 15:55:37 +0000
Subject: [PATCH] [NVPTX][NFC] Minor cleanup in NVPTXInstrInfo.td

---
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp |   2 -
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td  | 124 ++++++++---------------
 2 files changed, 40 insertions(+), 86 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
index 0551954444e57..67dc7904a91ae 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
@@ -209,9 +209,7 @@ bool NVPTXInstrInfo::isSchedulingBoundary(const MachineInstr &MI,
   switch (MI.getOpcode()) {
   case NVPTX::CallUniPrintCallRetInst1:
   case NVPTX::CallArgBeginInst:
-  case NVPTX::CallArgI32imm:
   case NVPTX::CallArgParam:
-  case NVPTX::LastCallArgI32imm:
   case NVPTX::LastCallArgParam:
   case NVPTX::CallArgEndInst1:
     return true;
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 043da14bcb236..11d77599d4ac3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1330,58 +1330,46 @@ def FDIV32ri_prec :
 // FMA
 //
 
-multiclass FMA<string OpcStr, RegisterClass RC, Operand ImmCls, Predicate Pred> {
+multiclass FMA<string OpcStr, RegTyInfo t, list<Predicate> Preds = []> {
   defvar asmstr = OpcStr # " \t$dst, $a, $b, $c;";
-  def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
+  def rrr : NVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, t.RC:$b, t.RC:$c),
                       asmstr,
-                      [(set RC:$dst, (fma RC:$a, RC:$b, RC:$c))]>,
-                      Requires<[Pred]>;
-  def rri : NVPTXInst<(outs RC:$dst),
-                      (ins RC:$a, RC:$b, ImmCls:$c),
-                      asmstr,
-                      [(set RC:$dst, (fma RC:$a, RC:$b, fpimm:$c))]>,
-                      Requires<[Pred]>;
-  def rir : NVPTXInst<(outs RC:$dst),
-                      (ins RC:$a, ImmCls:$b, RC:$c),
-                      asmstr,
-                      [(set RC:$dst, (fma RC:$a, fpimm:$b, RC:$c))]>,
-                      Requires<[Pred]>;
-  def rii : NVPTXInst<(outs RC:$dst),
-                      (ins RC:$a, ImmCls:$b, ImmCls:$c),
-                      asmstr,
-                      [(set RC:$dst, (fma RC:$a, fpimm:$b, fpimm:$c))]>,
-                      Requires<[Pred]>;
-  def iir : NVPTXInst<(outs RC:$dst),
-                      (ins ImmCls:$a, ImmCls:$b, RC:$c),
-                      asmstr,
-                      [(set RC:$dst, (fma fpimm:$a, fpimm:$b, RC:$c))]>,
-                      Requires<[Pred]>;
-
-}
-
-multiclass FMA_F16<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> {
-   def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
-                       !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
-                       [(set T:$dst, (fma T:$a, T:$b, T:$c))]>,
-                       Requires<[useFP16Math, Pred]>;
-}
-
-multiclass FMA_BF16<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> {
-   def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
-                       !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
-                       [(set T:$dst, (fma T:$a, T:$b, T:$c))]>,
-                       Requires<[hasBF16Math, Pred]>;
+                      [(set t.Ty:$dst, (fma t.Ty:$a, t.Ty:$b, t.Ty:$c))]>,
+                      Requires<Preds>;
+
+  if t.SupportsImm then {
+    def rri : NVPTXInst<(outs t.RC:$dst),
+                        (ins t.RC:$a, t.RC:$b, t.Imm:$c),
+                        asmstr,
+                        [(set t.Ty:$dst, (fma t.Ty:$a, t.Ty:$b, fpimm:$c))]>,
+                        Requires<Preds>;
+    def rir : NVPTXInst<(outs t.RC:$dst),
+                        (ins t.RC:$a, t.Imm:$b, t.RC:$c),
+                        asmstr,
+                        [(set t.Ty:$dst, (fma t.Ty:$a, fpimm:$b, t.Ty:$c))]>,
+                        Requires<Preds>;
+    def rii : NVPTXInst<(outs t.RC:$dst),
+                        (ins t.RC:$a, t.Imm:$b, t.Imm:$c),
+                        asmstr,
+                        [(set t.Ty:$dst, (fma t.Ty:$a, fpimm:$b, fpimm:$c))]>,
+                        Requires<Preds>;
+    def iir : NVPTXInst<(outs t.RC:$dst),
+                        (ins t.Imm:$a, t.Imm:$b, t.RC:$c),
+                        asmstr,
+                        [(set t.Ty:$dst, (fma fpimm:$a, fpimm:$b, t.Ty:$c))]>,
+                        Requires<Preds>;
+  }
 }
 
-defm FMA16_ftz    : FMA_F16<"fma.rn.ftz.f16", f16, Int16Regs, doF32FTZ>;
-defm FMA16        : FMA_F16<"fma.rn.f16", f16, Int16Regs, True>;
-defm FMA16x2_ftz  : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Int32Regs, doF32FTZ>;
-defm FMA16x2      : FMA_F16<"fma.rn.f16x2", v2f16, Int32Regs, True>;
-defm BFMA16       : FMA_BF16<"fma.rn.bf16", bf16, Int16Regs, True>;
-defm BFMA16x2     : FMA_BF16<"fma.rn.bf16x2", v2bf16, Int32Regs, True>;
-defm FMA32_ftz    : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>;
-defm FMA32        : FMA<"fma.rn.f32", Float32Regs, f32imm, True>;
-defm FMA64        : FMA<"fma.rn.f64", Float64Regs, f64imm, True>;
+defm FMA16_ftz    : FMA<"fma.rn.ftz.f16", F16RT, [useFP16Math, doF32FTZ]>;
+defm FMA16        : FMA<"fma.rn.f16", F16RT, [useFP16Math]>;
+defm FMA16x2_ftz  : FMA<"fma.rn.ftz.f16x2", F16X2RT, [useFP16Math, doF32FTZ]>;
+defm FMA16x2      : FMA<"fma.rn.f16x2", F16X2RT, [useFP16Math]>;
+defm BFMA16       : FMA<"fma.rn.bf16", BF16RT, [hasBF16Math]>;
+defm BFMA16x2     : FMA<"fma.rn.bf16x2", BF16X2RT, [hasBF16Math]>;
+defm FMA32_ftz    : FMA<"fma.rn.ftz.f32", F32RT, [doF32FTZ]>;
+defm FMA32        : FMA<"fma.rn.f32", F32RT>;
+defm FMA64        : FMA<"fma.rn.f64", F64RT>;
 
 // sin/cos
 
@@ -1999,7 +1987,7 @@ multiclass FSET_FORMAT<PatFrag OpNode, PatLeaf Mode, PatLeaf ModeFTZ> {
         Requires<[doF32FTZ]>;
   def : Pat<(i1 (OpNode f32:$a, f32:$b)),
             (SETP_f32rr $a, $b, Mode)>;
-  def : Pat<(i1 (OpNode Float32Regs:$a, fpimm:$b)),
+  def : Pat<(i1 (OpNode f32:$a, fpimm:$b)),
             (SETP_f32ri $a, fpimm:$b, ModeFTZ)>,
         Requires<[doF32FTZ]>;
   def : Pat<(i1 (OpNode f32:$a, fpimm:$b)),
@@ -2056,7 +2044,7 @@ def SDTStoreParamProfile : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>]>;
 def SDTStoreParamV2Profile : SDTypeProfile<0, 4, [SDTCisInt<0>, SDTCisInt<1>]>;
 def SDTStoreParamV4Profile : SDTypeProfile<0, 6, [SDTCisInt<0>, SDTCisInt<1>]>;
 def SDTStoreParam32Profile : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>]>;
-def SDTCallArgProfile : SDTypeProfile<0, 2, [SDTCisInt<0>]>;
+def SDTCallArgProfile : SDTypeProfile<0, 2, [SDTCisVT<0, i32>, SDTCisVT<1, i32>]>;
 def SDTCallArgMarkProfile : SDTypeProfile<0, 0, []>;
 def SDTCallVoidProfile : SDTypeProfile<0, 1, []>;
 def SDTCallValProfile : SDTypeProfile<1, 0, []>;
@@ -2352,42 +2340,10 @@ def CallArgEndInst1  : NVPTXInst<(outs), (ins), ");", [(CallArgEnd (i32 1))]>;
 def CallArgEndInst0  : NVPTXInst<(outs), (ins), ")", [(CallArgEnd (i32 0))]>;
 def RETURNInst       : NVPTXInst<(outs), (ins), "ret;", [(RETURNNode)]>;
 
-class CallArgInst<NVPTXRegClass regclass> :
-  NVPTXInst<(outs), (ins regclass:$a), "$a, ",
-            [(CallArg (i32 0), regclass:$a)]>;
-
-class CallArgInstVT<NVPTXRegClass regclass, ValueType vt> :
-  NVPTXInst<(outs), (ins regclass:$a), "$a, ",
-            [(CallArg (i32 0), vt:$a)]>;
-
-class LastCallArgInst<NVPTXRegClass regclass> :
-  NVPTXInst<(outs), (ins regclass:$a), "$a",
-            [(LastCallArg (i32 0), regclass:$a)]>;
-class LastCallArgInstVT<NVPTXRegClass regclass, ValueType vt> :
-  NVPTXInst<(outs), (ins regclass:$a), "$a",
-            [(LastCallArg (i32 0), vt:$a)]>;
-
-def CallArgI64     : CallArgInst<Int64Regs>;
-def CallArgI32     : CallArgInstVT<Int32Regs, i32>;
-def CallArgI16     : CallArgInstVT<Int16Regs, i16>;
-def CallArgF64     : CallArgInst<Float64Regs>;
-def CallArgF32     : CallArgInst<Float32Regs>;
-
-def LastCallArgI64 : LastCallArgInst<Int64Regs>;
-def LastCallArgI32 : LastCallArgInstVT<Int32Regs, i32>;
-def LastCallArgI16 : LastCallArgInstVT<Int16Regs, i16>;
-def LastCallArgF64 : LastCallArgInst<Float64Regs>;
-def LastCallArgF32 : LastCallArgInst<Float32Regs>;
-
-def CallArgI32imm : NVPTXInst<(outs), (ins i32imm:$a), "$a, ",
-                              [(CallArg (i32 0), (i32 imm:$a))]>;
-def LastCallArgI32imm : NVPTXInst<(outs), (ins i32imm:$a), "$a",
-                                  [(LastCallArg (i32 0), (i32 imm:$a))]>;
-
 def CallArgParam : NVPTXInst<(outs), (ins i32imm:$a), "param$a, ",
-                             [(CallArg (i32 1), (i32 imm:$a))]>;
+                             [(CallArg 1, imm:$a)]>;
 def LastCallArgParam : NVPTXInst<(outs), (ins i32imm:$a), "param$a",
-                                 [(LastCallArg (i32 1), (i32 imm:$a))]>;
+                                 [(LastCallArg 1, imm:$a)]>;
 
 def CallVoidInst :      NVPTXInst<(outs), (ins ADDR_base:$addr), "$addr, ",
                                   [(CallVoid (Wrapper tglobaladdr:$addr))]>;



More information about the llvm-commits mailing list