[llvm] [NVPTX] Rename register classes after float register removal (NFC) (PR #145255)

via llvm-commits llvm-commits at lists.llvm.org
Sun Jun 22 19:33:33 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

<details>
<summary>Changes</summary>



---

The patch is 368.54 KiB and has been truncated to 20.00 KiB below; the full version is available at: https://github.com/llvm/llvm-project/pull/145255.diff


7 Files Affected:

- (modified) llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp (+5-5) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+17-17) 
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp (+5-5) 
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+443-478) 
- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+1423-2096) 
- (modified) llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp (+10-10) 
- (modified) llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td (+5-13) 


``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 9af6fb2cb198e..38912a7f09e30 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -215,15 +215,15 @@ unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
     // Encode the register class in the upper 4 bits
     // Must be kept in sync with NVPTXInstPrinter::printRegName
     unsigned Ret = 0;
-    if (RC == &NVPTX::Int1RegsRegClass) {
+    if (RC == &NVPTX::B1RegClass) {
       Ret = (1 << 28);
-    } else if (RC == &NVPTX::Int16RegsRegClass) {
+    } else if (RC == &NVPTX::B16RegClass) {
       Ret = (2 << 28);
-    } else if (RC == &NVPTX::Int32RegsRegClass) {
+    } else if (RC == &NVPTX::B32RegClass) {
       Ret = (3 << 28);
-    } else if (RC == &NVPTX::Int64RegsRegClass) {
+    } else if (RC == &NVPTX::B64RegClass) {
       Ret = (4 << 28);
-    } else if (RC == &NVPTX::Int128RegsRegClass) {
+    } else if (RC == &NVPTX::B128RegClass) {
       Ret = (7 << 28);
     } else {
       report_fatal_error("Bad register class");
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 492f4ab76fdbb..676654d6d33e7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -589,18 +589,18 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
     setOperationAction(Op, VT, IsOpSupported ? Action : NoI16x2Action);
   };
 
-  addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
-  addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
-  addRegisterClass(MVT::v2i16, &NVPTX::Int32RegsRegClass);
-  addRegisterClass(MVT::v4i8, &NVPTX::Int32RegsRegClass);
-  addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
-  addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
-  addRegisterClass(MVT::f32, &NVPTX::Int32RegsRegClass);
-  addRegisterClass(MVT::f64, &NVPTX::Int64RegsRegClass);
-  addRegisterClass(MVT::f16, &NVPTX::Int16RegsRegClass);
-  addRegisterClass(MVT::v2f16, &NVPTX::Int32RegsRegClass);
-  addRegisterClass(MVT::bf16, &NVPTX::Int16RegsRegClass);
-  addRegisterClass(MVT::v2bf16, &NVPTX::Int32RegsRegClass);
+  addRegisterClass(MVT::i1, &NVPTX::B1RegClass);
+  addRegisterClass(MVT::i16, &NVPTX::B16RegClass);
+  addRegisterClass(MVT::v2i16, &NVPTX::B32RegClass);
+  addRegisterClass(MVT::v4i8, &NVPTX::B32RegClass);
+  addRegisterClass(MVT::i32, &NVPTX::B32RegClass);
+  addRegisterClass(MVT::i64, &NVPTX::B64RegClass);
+  addRegisterClass(MVT::f32, &NVPTX::B32RegClass);
+  addRegisterClass(MVT::f64, &NVPTX::B64RegClass);
+  addRegisterClass(MVT::f16, &NVPTX::B16RegClass);
+  addRegisterClass(MVT::v2f16, &NVPTX::B32RegClass);
+  addRegisterClass(MVT::bf16, &NVPTX::B16RegClass);
+  addRegisterClass(MVT::v2bf16, &NVPTX::B32RegClass);
 
   // Conversion to/from FP16/FP16x2 is always legal.
   setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom);
@@ -4866,22 +4866,22 @@ NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
   if (Constraint.size() == 1) {
     switch (Constraint[0]) {
     case 'b':
-      return std::make_pair(0U, &NVPTX::Int1RegsRegClass);
+      return std::make_pair(0U, &NVPTX::B1RegClass);
     case 'c':
     case 'h':
-      return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
+      return std::make_pair(0U, &NVPTX::B16RegClass);
     case 'r':
     case 'f':
-      return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
+      return std::make_pair(0U, &NVPTX::B32RegClass);
     case 'l':
     case 'N':
     case 'd':
-      return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
+      return std::make_pair(0U, &NVPTX::B64RegClass);
     case 'q': {
       if (STI.getSmVersion() < 70)
         report_fatal_error("Inline asm with 128 bit operands is only "
                            "supported for sm_70 and higher!");
-      return std::make_pair(0U, &NVPTX::Int128RegsRegClass);
+      return std::make_pair(0U, &NVPTX::B128RegClass);
     }
     }
   }
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
index f262a0fb66c25..bf84d1dca4ed5 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
@@ -39,15 +39,15 @@ void NVPTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
     report_fatal_error("Copy one register into another with a different width");
 
   unsigned Op;
-  if (DestRC == &NVPTX::Int1RegsRegClass) {
+  if (DestRC == &NVPTX::B1RegClass) {
     Op = NVPTX::IMOV1r;
-  } else if (DestRC == &NVPTX::Int16RegsRegClass) {
+  } else if (DestRC == &NVPTX::B16RegClass) {
     Op = NVPTX::MOV16r;
-  } else if (DestRC == &NVPTX::Int32RegsRegClass) {
+  } else if (DestRC == &NVPTX::B32RegClass) {
     Op = NVPTX::IMOV32r;
-  } else if (DestRC == &NVPTX::Int64RegsRegClass) {
+  } else if (DestRC == &NVPTX::B64RegClass) {
     Op = NVPTX::IMOV64r;
-  } else if (DestRC == &NVPTX::Int128RegsRegClass) {
+  } else if (DestRC == &NVPTX::B128RegClass) {
     Op = NVPTX::IMOV128r;
   } else {
     llvm_unreachable("Bad register copy");
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index bbe99dec5c445..5979054764647 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -170,29 +170,6 @@ def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70"
 def useFP16Math: Predicate<"Subtarget->allowFP16Math()">;
 def hasBF16Math: Predicate<"Subtarget->hasBF16Math()">;
 
-// Helper class to aid conversion between ValueType and a matching RegisterClass.
-
-class ValueToRegClass<ValueType T> {
-   string name = !cast<string>(T);
-   NVPTXRegClass ret = !cond(
-     !eq(name, "i1"): Int1Regs,
-     !eq(name, "i16"): Int16Regs,
-     !eq(name, "v2i16"): Int32Regs,
-     !eq(name, "i32"): Int32Regs,
-     !eq(name, "i64"): Int64Regs,
-     !eq(name, "f16"): Int16Regs,
-     !eq(name, "v2f16"): Int32Regs,
-     !eq(name, "bf16"): Int16Regs,
-     !eq(name, "v2bf16"): Int32Regs,
-     !eq(name, "f32"): Float32Regs,
-     !eq(name, "f64"): Float64Regs,
-     !eq(name, "ai32"): Int32ArgRegs,
-     !eq(name, "ai64"): Int64ArgRegs,
-     !eq(name, "af32"): Float32ArgRegs,
-     !eq(name, "if64"): Float64ArgRegs,
-    );
-}
-
 
 //===----------------------------------------------------------------------===//
 // Some Common Instruction Class Templates
@@ -219,18 +196,18 @@ class RegTyInfo<ValueType ty, NVPTXRegClass rc, Operand imm, SDNode imm_node,
   int Size = ty.Size;
 }
 
-def I1RT     : RegTyInfo<i1,  Int1Regs,  i1imm,  imm>;
-def I16RT    : RegTyInfo<i16, Int16Regs, i16imm, imm>;
-def I32RT    : RegTyInfo<i32, Int32Regs, i32imm, imm>;
-def I64RT    : RegTyInfo<i64, Int64Regs, i64imm, imm>;
+def I1RT     : RegTyInfo<i1,  B1,  i1imm,  imm>;
+def I16RT    : RegTyInfo<i16, B16, i16imm, imm>;
+def I32RT    : RegTyInfo<i32, B32, i32imm, imm>;
+def I64RT    : RegTyInfo<i64, B64, i64imm, imm>;
 
-def F32RT    : RegTyInfo<f32, Float32Regs, f32imm, fpimm>;
-def F64RT    : RegTyInfo<f64, Float64Regs, f64imm, fpimm>;
-def F16RT    : RegTyInfo<f16, Int16Regs, f16imm, fpimm, supports_imm = 0>;
-def BF16RT   : RegTyInfo<bf16, Int16Regs, bf16imm, fpimm, supports_imm = 0>;
+def F32RT    : RegTyInfo<f32, B32, f32imm, fpimm>;
+def F64RT    : RegTyInfo<f64, B64, f64imm, fpimm>;
+def F16RT    : RegTyInfo<f16, B16, f16imm, fpimm, supports_imm = 0>;
+def BF16RT   : RegTyInfo<bf16, B16, bf16imm, fpimm, supports_imm = 0>;
 
-def F16X2RT  : RegTyInfo<v2f16, Int32Regs, ?, ?, supports_imm = 0>;
-def BF16X2RT : RegTyInfo<v2bf16, Int32Regs, ?, ?, supports_imm = 0>;
+def F16X2RT  : RegTyInfo<v2f16, B32, ?, ?, supports_imm = 0>;
+def BF16X2RT : RegTyInfo<v2bf16, B32, ?, ?, supports_imm = 0>;
 
 
 // This class provides a basic wrapper around an NVPTXInst that abstracts the
@@ -238,18 +215,18 @@ def BF16X2RT : RegTyInfo<v2bf16, Int32Regs, ?, ?, supports_imm = 0>;
 // construction of the asm string based on the provided dag arguments.
 // For example, the following asm-strings would be computed:
 //
-//   * BasicFlagsNVPTXInst<(outs Int32Regs:$dst),
-//                         (ins Int32Regs:$a, Int32Regs:$b), (ins),
+//   * BasicFlagsNVPTXInst<(outs B32:$dst),
+//                         (ins B32:$a, B32:$b), (ins),
 //                         "add.s32">;
 //         ---> "add.s32 \t$dst, $a, $b;"
 //
-//   * BasicFlagsNVPTXInst<(outs Int32Regs:$d),
-//                         (ins Int32Regs:$a, Int32Regs:$b, Hexu32imm:$c),
+//   * BasicFlagsNVPTXInst<(outs B32:$d),
+//                         (ins B32:$a, B32:$b, Hexu32imm:$c),
 //                         (ins PrmtMode:$mode),
 //                         "prmt.b32${mode}">;
 //         ---> "prmt.b32${mode} \t$d, $a, $b, $c;"
 //
-//   * BasicFlagsNVPTXInst<(outs Int64Regs:$state),
+//   * BasicFlagsNVPTXInst<(outs B64:$state),
 //                         (ins ADDR:$addr),
 //                         "mbarrier.arrive.b64">;
 //         ---> "mbarrier.arrive.b64 \t$state, [$addr];"
@@ -312,7 +289,7 @@ multiclass I3<string op_str, SDPatternOperator op_node, bit commutative> {
 }
 
 class I16x2<string OpcStr, SDNode OpNode> :
-  BasicNVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b),
+  BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
               OpcStr # "16x2",
               [(set v2i16:$dst, (OpNode v2i16:$a, v2i16:$b))]>,
               Requires<[hasPTX<80>, hasSM<90>]>;
@@ -334,73 +311,73 @@ multiclass ADD_SUB_INT_CARRY<string op_str, SDNode op_node, bit commutative> {
 multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
   if !not(NaN) then {
    def f64rr :
-     BasicNVPTXInst<(outs Float64Regs:$dst),
-               (ins Float64Regs:$a, Float64Regs:$b),
+     BasicNVPTXInst<(outs B64:$dst),
+               (ins B64:$a, B64:$b),
                OpcStr # ".f64",
                [(set f64:$dst, (OpNode f64:$a, f64:$b))]>;
    def f64ri :
-     BasicNVPTXInst<(outs Float64Regs:$dst),
-               (ins Float64Regs:$a, f64imm:$b),
+     BasicNVPTXInst<(outs B64:$dst),
+               (ins B64:$a, f64imm:$b),
                OpcStr # ".f64",
                [(set f64:$dst, (OpNode f64:$a, fpimm:$b))]>;
   }
    def f32rr_ftz :
-     BasicNVPTXInst<(outs Float32Regs:$dst),
-               (ins Float32Regs:$a, Float32Regs:$b),
+     BasicNVPTXInst<(outs B32:$dst),
+               (ins B32:$a, B32:$b),
                OpcStr # ".ftz.f32",
                [(set f32:$dst, (OpNode f32:$a, f32:$b))]>,
                Requires<[doF32FTZ]>;
    def f32ri_ftz :
-     BasicNVPTXInst<(outs Float32Regs:$dst),
-               (ins Float32Regs:$a, f32imm:$b),
+     BasicNVPTXInst<(outs B32:$dst),
+               (ins B32:$a, f32imm:$b),
                OpcStr # ".ftz.f32",
                [(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>,
                Requires<[doF32FTZ]>;
    def f32rr :
-     BasicNVPTXInst<(outs Float32Regs:$dst),
-               (ins Float32Regs:$a, Float32Regs:$b),
+     BasicNVPTXInst<(outs B32:$dst),
+               (ins B32:$a, B32:$b),
                OpcStr # ".f32",
                [(set f32:$dst, (OpNode f32:$a, f32:$b))]>;
    def f32ri :
-     BasicNVPTXInst<(outs Float32Regs:$dst),
-               (ins Float32Regs:$a, f32imm:$b),
+     BasicNVPTXInst<(outs B32:$dst),
+               (ins B32:$a, f32imm:$b),
                OpcStr # ".f32",
                [(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>;
 
    def f16rr_ftz :
-     BasicNVPTXInst<(outs Int16Regs:$dst),
-               (ins Int16Regs:$a, Int16Regs:$b),
+     BasicNVPTXInst<(outs B16:$dst),
+               (ins B16:$a, B16:$b),
                OpcStr # ".ftz.f16",
                [(set f16:$dst, (OpNode f16:$a, f16:$b))]>,
                Requires<[useFP16Math, doF32FTZ]>;
    def f16rr :
-     BasicNVPTXInst<(outs Int16Regs:$dst),
-               (ins Int16Regs:$a, Int16Regs:$b),
+     BasicNVPTXInst<(outs B16:$dst),
+               (ins B16:$a, B16:$b),
                OpcStr # ".f16",
                [(set f16:$dst, (OpNode f16:$a, f16:$b))]>,
                Requires<[useFP16Math, hasSM<80>, hasPTX<70>]>;
 
    def f16x2rr_ftz :
-     BasicNVPTXInst<(outs Int32Regs:$dst),
-               (ins Int32Regs:$a, Int32Regs:$b),
+     BasicNVPTXInst<(outs B32:$dst),
+               (ins B32:$a, B32:$b),
                OpcStr # ".ftz.f16x2",
                [(set v2f16:$dst, (OpNode v2f16:$a, v2f16:$b))]>,
                Requires<[useFP16Math, hasSM<80>, hasPTX<70>, doF32FTZ]>;
    def f16x2rr :
-     BasicNVPTXInst<(outs Int32Regs:$dst),
-               (ins Int32Regs:$a, Int32Regs:$b),
+     BasicNVPTXInst<(outs B32:$dst),
+               (ins B32:$a, B32:$b),
                OpcStr # ".f16x2",
                [(set v2f16:$dst, (OpNode v2f16:$a, v2f16:$b))]>,
                Requires<[useFP16Math, hasSM<80>, hasPTX<70>]>;
    def bf16rr :
-     BasicNVPTXInst<(outs Int16Regs:$dst),
-               (ins Int16Regs:$a, Int16Regs:$b),
+     BasicNVPTXInst<(outs B16:$dst),
+               (ins B16:$a, B16:$b),
                OpcStr # ".bf16",
                [(set bf16:$dst, (OpNode bf16:$a, bf16:$b))]>,
                Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
    def bf16x2rr :
-     BasicNVPTXInst<(outs Int32Regs:$dst),
-               (ins Int32Regs:$a, Int32Regs:$b),
+     BasicNVPTXInst<(outs B32:$dst),
+               (ins B32:$a, B32:$b),
                OpcStr # ".bf16x2",
                [(set v2bf16:$dst, (OpNode v2bf16:$a, v2bf16:$b))]>,
                Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
@@ -417,73 +394,73 @@ multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
 // just like the non ".rn" op, but prevents ptxas from creating FMAs.
 multiclass F3<string op_str, SDPatternOperator op_pat> {
   def f64rr :
-    BasicNVPTXInst<(outs Float64Regs:$dst),
-              (ins Float64Regs:$a, Float64Regs:$b),
+    BasicNVPTXInst<(outs B64:$dst),
+              (ins B64:$a, B64:$b),
               op_str # ".f64",
               [(set f64:$dst, (op_pat f64:$a, f64:$b))]>;
   def f64ri :
-    BasicNVPTXInst<(outs Float64Regs:$dst),
-              (ins Float64Regs:$a, f64imm:$b),
+    BasicNVPTXInst<(outs B64:$dst),
+              (ins B64:$a, f64imm:$b),
               op_str # ".f64",
               [(set f64:$dst, (op_pat f64:$a, fpimm:$b))]>;
   def f32rr_ftz :
-    BasicNVPTXInst<(outs Float32Regs:$dst),
-              (ins Float32Regs:$a, Float32Regs:$b),
+    BasicNVPTXInst<(outs B32:$dst),
+              (ins B32:$a, B32:$b),
               op_str # ".ftz.f32",
               [(set f32:$dst, (op_pat f32:$a, f32:$b))]>,
               Requires<[doF32FTZ]>;
   def f32ri_ftz :
-    BasicNVPTXInst<(outs Float32Regs:$dst),
-              (ins Float32Regs:$a, f32imm:$b),
+    BasicNVPTXInst<(outs B32:$dst),
+              (ins B32:$a, f32imm:$b),
               op_str # ".ftz.f32",
               [(set f32:$dst, (op_pat f32:$a, fpimm:$b))]>,
               Requires<[doF32FTZ]>;
   def f32rr :
-    BasicNVPTXInst<(outs Float32Regs:$dst),
-              (ins Float32Regs:$a, Float32Regs:$b),
+    BasicNVPTXInst<(outs B32:$dst),
+              (ins B32:$a, B32:$b),
               op_str # ".f32",
               [(set f32:$dst, (op_pat f32:$a, f32:$b))]>;
   def f32ri :
-    BasicNVPTXInst<(outs Float32Regs:$dst),
-              (ins Float32Regs:$a, f32imm:$b),
+    BasicNVPTXInst<(outs B32:$dst),
+              (ins B32:$a, f32imm:$b),
               op_str # ".f32",
               [(set f32:$dst, (op_pat f32:$a, fpimm:$b))]>;
 
   def f16rr_ftz :
-    BasicNVPTXInst<(outs Int16Regs:$dst),
-              (ins Int16Regs:$a, Int16Regs:$b),
+    BasicNVPTXInst<(outs B16:$dst),
+              (ins B16:$a, B16:$b),
               op_str # ".ftz.f16",
               [(set f16:$dst, (op_pat f16:$a, f16:$b))]>,
               Requires<[useFP16Math, doF32FTZ]>;
   def f16rr :
-    BasicNVPTXInst<(outs Int16Regs:$dst),
-              (ins Int16Regs:$a, Int16Regs:$b),
+    BasicNVPTXInst<(outs B16:$dst),
+              (ins B16:$a, B16:$b),
               op_str # ".f16",
               [(set f16:$dst, (op_pat f16:$a, f16:$b))]>,
               Requires<[useFP16Math]>;
 
   def f16x2rr_ftz :
-    BasicNVPTXInst<(outs Int32Regs:$dst),
-              (ins Int32Regs:$a, Int32Regs:$b),
+    BasicNVPTXInst<(outs B32:$dst),
+              (ins B32:$a, B32:$b),
               op_str # ".ftz.f16x2",
               [(set v2f16:$dst, (op_pat v2f16:$a, v2f16:$b))]>,
               Requires<[useFP16Math, doF32FTZ]>;
   def f16x2rr :
-    BasicNVPTXInst<(outs Int32Regs:$dst),
-              (ins Int32Regs:$a, Int32Regs:$b),
+    BasicNVPTXInst<(outs B32:$dst),
+              (ins B32:$a, B32:$b),
               op_str # ".f16x2",
               [(set v2f16:$dst, (op_pat v2f16:$a, v2f16:$b))]>,
               Requires<[useFP16Math]>;
   def bf16rr :
-    BasicNVPTXInst<(outs Int16Regs:$dst),
-              (ins Int16Regs:$a, Int16Regs:$b),
+    BasicNVPTXInst<(outs B16:$dst),
+              (ins B16:$a, B16:$b),
               op_str # ".bf16",
               [(set bf16:$dst, (op_pat bf16:$a, bf16:$b))]>,
               Requires<[hasBF16Math]>;
 
   def bf16x2rr :
-    BasicNVPTXInst<(outs Int32Regs:$dst),
-              (ins Int32Regs:$a, Int32Regs:$b),
+    BasicNVPTXInst<(outs B32:$dst),
+              (ins B32:$a, B32:$b),
               op_str # ".bf16x2",
               [(set v2bf16:$dst, (op_pat v2bf16:$a, v2bf16:$b))]>,
               Requires<[hasBF16Math]>;
@@ -504,40 +481,40 @@ multiclass F3_fma_component<string op_str, SDNode op_node> {
 // instructions: <OpcStr>.f64, <OpcStr>.f32, and <OpcStr>.ftz.f32 (flush
 // subnormal inputs and results to zero).
 multiclass F2<string OpcStr, SDNode OpNode> {
-   def f64 :     BasicNVPTXInst<(outs Float64Regs:$dst), (ins Float64Regs:$a),
+   def f64 :     BasicNVPTXInst<(outs B64:$dst), (ins B64:$a),
                            OpcStr # ".f64",
                            [(set f64:$dst, (OpNode f64:$a))]>;
-   def f32_ftz : BasicNVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$a),
+   def f32_ftz : BasicNVPTXInst<(outs B32:$dst), (ins B32:$a),
                            OpcStr # ".ftz.f32",
                            [(set f32:$dst, (OpNode f32:$a))]>,
                            Requires<[doF32FTZ]>;
-   def f32 :     BasicNVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$a),
+   def f32 :     BasicNVPTXInst<(outs B32:$dst), (ins B32:$a),
                            OpcStr # ".f32",
                            [(set f32:$dst, (OpNode f32:$a))]>;
 }
 
 multiclass F2_Support_Half<string OpcStr, SDNode OpNode> {
-   def bf16 :      BasicNVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a),
+   def bf16 :      BasicNVPTXInst<(outs B16:$dst), (ins B16:$a),
                            OpcStr # ".bf16",
                            [(set bf16:$dst, (OpNode bf16:$a))]>,
                            Requires<[hasSM<80>, hasPTX<70>]>;
-   def bf16x2 :    BasicNVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a),
+   def bf16x2 :    BasicNVPTXInst<(outs B32:$dst), (ins B32:$a),
                            OpcStr # ".bf16x2",
                            [(set v2bf16:$dst, (OpNode v2bf16:$a))]>,
                            Requires<[hasSM<80>, hasPTX<70>]>;
-   def f16_ftz :   BasicNVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a),
+   def f16_ftz :   BasicNVPTXInst<(outs B16:$dst), (ins B16:$a),
                            OpcStr # ".ftz.f16",
                            [(set f16:$dst, (OpNode f16:$a))]>,
                            Requires<[hasSM<53>, hasPTX<65>, doF32FTZ]>;
-   def f16x2_ftz : BasicNVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a),
+   def f16x2_ftz : BasicNVPTXInst<(outs B32:$dst), (ins B32:$a),
                            OpcStr # ".ftz.f16x2",
                            [(set v2f16:$dst, (OpNode v2f16:$a))]>,
                            Requires<[hasSM<53>, hasPTX<65>, doF32FTZ]>;
-   def f16 :       BasicNVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a),
+   def f16 :       BasicNVPTXInst<(outs B16:$dst), (ins B16:$a),
                            OpcStr # ".f16",
                            [(set f16:$dst, (OpNode f16:$a))]>,
                            Requires<[hasSM<53>, hasPTX<65>]>;
-   def f16x2 :     BasicNVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a),
+   def f16x2 :     BasicNVPTXInst<(outs B32:$dst), (ins B32:$a),
                            OpcStr # ".f16x2",
      ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/145255


More information about the llvm-commits mailing list