[llvm] [NVPTX] Rework and cleanup FTZ ISel (PR #146410)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 9 09:48:05 PDT 2025


================
@@ -1266,86 +1152,73 @@ def fdiv_ftz : PatFrag<(ops node:$a, node:$b),
   return getDivF32Level(N) == NVPTX::DivPrecisionLevel::IEEE754;
 }]>;
 
-def FRCP32r_prec_ftz :
-  BasicNVPTXInst<(outs B32:$dst),
-                 (ins B32:$b),
-                 "rcp.rn.ftz.f32",
-                 [(set f32:$dst, (fdiv_ftz f32imm_1, f32:$b))]>,
-                 Requires<[doF32FTZ]>;
 def FRCP32r_prec :
-  BasicNVPTXInst<(outs B32:$dst),
+  BasicFlagsNVPTXInst<(outs B32:$dst),
                  (ins B32:$b),
-                 "rcp.rn.f32",
-                 [(set f32:$dst, (fdiv f32imm_1, f32:$b))]>;
+                 (ins FTZFlag:$ftz),
+                 "rcp.rn$ftz.f32",
+                 [(set f32:$dst, (fdiv_ftz f32imm_1, f32:$b))]>;
 //
 // F32 Accurate division
 //
-def FDIV32rr_prec_ftz :
-  BasicNVPTXInst<(outs B32:$dst),
-                 (ins B32:$a, B32:$b),
-                 "div.rn.ftz.f32",
-                 [(set f32:$dst, (fdiv_ftz f32:$a, f32:$b))]>,
-                 Requires<[doF32FTZ]>;
-def FDIV32ri_prec_ftz :
-  BasicNVPTXInst<(outs B32:$dst),
-                 (ins B32:$a, f32imm:$b),
-                 "div.rn.ftz.f32",
-                 [(set f32:$dst, (fdiv_ftz f32:$a, fpimm:$b))]>,
-                 Requires<[doF32FTZ]>;
 def FDIV32rr_prec :
-  BasicNVPTXInst<(outs B32:$dst),
+  BasicFlagsNVPTXInst<(outs B32:$dst),
                  (ins B32:$a, B32:$b),
-                 "div.rn.f32",
-                 [(set f32:$dst, (fdiv f32:$a, f32:$b))]>;
+                 (ins FTZFlag:$ftz),
+                 "div.rn$ftz.f32",
+                 [(set f32:$dst, (fdiv_ftz f32:$a, f32:$b))]>;
 def FDIV32ri_prec :
-  BasicNVPTXInst<(outs B32:$dst),
+  BasicFlagsNVPTXInst<(outs B32:$dst),
                  (ins B32:$a, f32imm:$b),
-                 "div.rn.f32",
-                 [(set f32:$dst, (fdiv f32:$a, fpimm:$b))]>;
+                 (ins FTZFlag:$ftz),
+                 "div.rn$ftz.f32",
+                 [(set f32:$dst, (fdiv_ftz f32:$a, fpimm:$b))]>;
+
+def : Pat<(fdiv f32imm_1, f32:$b), (FRCP32r_prec $b, NoFTZ)>;
+def : Pat<(fdiv f32:$a, f32:$b), (FDIV32rr_prec $a, $b, NoFTZ)>;
+def : Pat<(fdiv f32:$a, fpimm:$b), (FDIV32ri_prec $a, fpimm:$b, NoFTZ)>;
 
 //
 // FMA
 //
 
-multiclass FMA<string asmstr, RegTyInfo t, list<Predicate> Preds = []> {
-  def rrr : BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, t.RC:$b, t.RC:$c),
-                      asmstr,
-                      [(set t.Ty:$dst, (fma t.Ty:$a, t.Ty:$b, t.Ty:$c))]>,
-                      Requires<Preds>;
-
-  if t.SupportsImm then {
-    def rri : BasicNVPTXInst<(outs t.RC:$dst),
-                        (ins t.RC:$a, t.RC:$b, t.Imm:$c),
-                        asmstr,
-                        [(set t.Ty:$dst, (fma t.Ty:$a, t.Ty:$b, fpimm:$c))]>,
-                        Requires<Preds>;
-    def rir : BasicNVPTXInst<(outs t.RC:$dst),
-                        (ins t.RC:$a, t.Imm:$b, t.RC:$c),
-                        asmstr,
-                        [(set t.Ty:$dst, (fma t.Ty:$a, fpimm:$b, t.Ty:$c))]>,
-                        Requires<Preds>;
-    def rii : BasicNVPTXInst<(outs t.RC:$dst),
-                        (ins t.RC:$a, t.Imm:$b, t.Imm:$c),
-                        asmstr,
-                        [(set t.Ty:$dst, (fma t.Ty:$a, fpimm:$b, fpimm:$c))]>,
-                        Requires<Preds>;
-    def iir : BasicNVPTXInst<(outs t.RC:$dst),
-                        (ins t.Imm:$a, t.Imm:$b, t.RC:$c),
-                        asmstr,
-                        [(set t.Ty:$dst, (fma fpimm:$a, fpimm:$b, t.Ty:$c))]>,
-                        Requires<Preds>;
+multiclass FMA<RegTyInfo t, bit allow_ftz = true, list<Predicate> preds = []> {
+  defvar flag_str = !if(allow_ftz, "$ftz", "");
+  defvar flag_ops = !if(allow_ftz, (ins FTZFlag:$ftz), (ins));
+  defvar op_str = "fma.rn" # flag_str # "." # t.Str;
+
+  let Predicates = preds in {
+    def rrr : BasicFlagsNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, t.RC:$b, t.RC:$c),
+                        flag_ops, op_str,
+                        [(set t.Ty:$dst, (fma t.Ty:$a, t.Ty:$b, t.Ty:$c))]>;
+
+    if t.SupportsImm then {
+      def rri : BasicFlagsNVPTXInst<(outs t.RC:$dst),
+                          (ins t.RC:$a, t.RC:$b, t.Imm:$c),
+                          flag_ops, op_str,
+                          [(set t.Ty:$dst, (fma t.Ty:$a, t.Ty:$b, fpimm:$c))]>;
+      def rir : BasicFlagsNVPTXInst<(outs t.RC:$dst),
+                          (ins t.RC:$a, t.Imm:$b, t.RC:$c),
+                          flag_ops, op_str,
+                          [(set t.Ty:$dst, (fma t.Ty:$a, fpimm:$b, t.Ty:$c))]>;
+      def rii : BasicFlagsNVPTXInst<(outs t.RC:$dst),
+                          (ins t.RC:$a, t.Imm:$b, t.Imm:$c),
+                          flag_ops, op_str,
+                          [(set t.Ty:$dst, (fma t.Ty:$a, fpimm:$b, fpimm:$c))]>;
+      def iir : BasicFlagsNVPTXInst<(outs t.RC:$dst),
+                          (ins t.Imm:$a, t.Imm:$b, t.RC:$c),
+                          flag_ops, op_str,
+                          [(set t.Ty:$dst, (fma fpimm:$a, fpimm:$b, t.Ty:$c))]>;
+    }
   }
 }
 
-defm FMA16_ftz    : FMA<"fma.rn.ftz.f16", F16RT, [useFP16Math, doF32FTZ]>;
-defm FMA16        : FMA<"fma.rn.f16", F16RT, [useFP16Math]>;
-defm FMA16x2_ftz  : FMA<"fma.rn.ftz.f16x2", F16X2RT, [useFP16Math, doF32FTZ]>;
-defm FMA16x2      : FMA<"fma.rn.f16x2", F16X2RT, [useFP16Math]>;
-defm BFMA16       : FMA<"fma.rn.bf16", BF16RT, [hasBF16Math]>;
-defm BFMA16x2     : FMA<"fma.rn.bf16x2", BF16X2RT, [hasBF16Math]>;
-defm FMA32_ftz    : FMA<"fma.rn.ftz.f32", F32RT, [doF32FTZ]>;
-defm FMA32        : FMA<"fma.rn.f32", F32RT>;
-defm FMA64        : FMA<"fma.rn.f64", F64RT>;
+defm FMA_F16    : FMA<F16RT,    allow_ftz = true, preds = [useFP16Math]>;
----------------
AlexMaclean wrote:

I think there are enough exceptions that it is best not to treat allow_ftz as a property of the float register classes. Different instructions have enough quirks that it seems cleanest and simplest to define this on a per-instruction basis.

https://github.com/llvm/llvm-project/pull/146410


More information about the llvm-commits mailing list