[llvm] [NVPTX] Rework and cleanup FTZ ISel (PR #146410)
Alex MacLean via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 9 09:48:05 PDT 2025
================
@@ -1266,86 +1152,73 @@ def fdiv_ftz : PatFrag<(ops node:$a, node:$b),
return getDivF32Level(N) == NVPTX::DivPrecisionLevel::IEEE754;
}]>;
-def FRCP32r_prec_ftz :
- BasicNVPTXInst<(outs B32:$dst),
- (ins B32:$b),
- "rcp.rn.ftz.f32",
- [(set f32:$dst, (fdiv_ftz f32imm_1, f32:$b))]>,
- Requires<[doF32FTZ]>;
def FRCP32r_prec :
- BasicNVPTXInst<(outs B32:$dst),
+ BasicFlagsNVPTXInst<(outs B32:$dst),
(ins B32:$b),
- "rcp.rn.f32",
- [(set f32:$dst, (fdiv f32imm_1, f32:$b))]>;
+ (ins FTZFlag:$ftz),
+ "rcp.rn$ftz.f32",
+ [(set f32:$dst, (fdiv_ftz f32imm_1, f32:$b))]>;
//
// F32 Accurate division
//
-def FDIV32rr_prec_ftz :
- BasicNVPTXInst<(outs B32:$dst),
- (ins B32:$a, B32:$b),
- "div.rn.ftz.f32",
- [(set f32:$dst, (fdiv_ftz f32:$a, f32:$b))]>,
- Requires<[doF32FTZ]>;
-def FDIV32ri_prec_ftz :
- BasicNVPTXInst<(outs B32:$dst),
- (ins B32:$a, f32imm:$b),
- "div.rn.ftz.f32",
- [(set f32:$dst, (fdiv_ftz f32:$a, fpimm:$b))]>,
- Requires<[doF32FTZ]>;
def FDIV32rr_prec :
- BasicNVPTXInst<(outs B32:$dst),
+ BasicFlagsNVPTXInst<(outs B32:$dst),
(ins B32:$a, B32:$b),
- "div.rn.f32",
- [(set f32:$dst, (fdiv f32:$a, f32:$b))]>;
+ (ins FTZFlag:$ftz),
+ "div.rn$ftz.f32",
+ [(set f32:$dst, (fdiv_ftz f32:$a, f32:$b))]>;
def FDIV32ri_prec :
- BasicNVPTXInst<(outs B32:$dst),
+ BasicFlagsNVPTXInst<(outs B32:$dst),
(ins B32:$a, f32imm:$b),
- "div.rn.f32",
- [(set f32:$dst, (fdiv f32:$a, fpimm:$b))]>;
+ (ins FTZFlag:$ftz),
+ "div.rn$ftz.f32",
+ [(set f32:$dst, (fdiv_ftz f32:$a, fpimm:$b))]>;
+
+def : Pat<(fdiv f32imm_1, f32:$b), (FRCP32r_prec $b, NoFTZ)>;
+def : Pat<(fdiv f32:$a, f32:$b), (FDIV32rr_prec $a, $b, NoFTZ)>;
+def : Pat<(fdiv f32:$a, fpimm:$b), (FDIV32ri_prec $a, fpimm:$b, NoFTZ)>;
//
// FMA
//
-multiclass FMA<string asmstr, RegTyInfo t, list<Predicate> Preds = []> {
- def rrr : BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, t.RC:$b, t.RC:$c),
- asmstr,
- [(set t.Ty:$dst, (fma t.Ty:$a, t.Ty:$b, t.Ty:$c))]>,
- Requires<Preds>;
-
- if t.SupportsImm then {
- def rri : BasicNVPTXInst<(outs t.RC:$dst),
- (ins t.RC:$a, t.RC:$b, t.Imm:$c),
- asmstr,
- [(set t.Ty:$dst, (fma t.Ty:$a, t.Ty:$b, fpimm:$c))]>,
- Requires<Preds>;
- def rir : BasicNVPTXInst<(outs t.RC:$dst),
- (ins t.RC:$a, t.Imm:$b, t.RC:$c),
- asmstr,
- [(set t.Ty:$dst, (fma t.Ty:$a, fpimm:$b, t.Ty:$c))]>,
- Requires<Preds>;
- def rii : BasicNVPTXInst<(outs t.RC:$dst),
- (ins t.RC:$a, t.Imm:$b, t.Imm:$c),
- asmstr,
- [(set t.Ty:$dst, (fma t.Ty:$a, fpimm:$b, fpimm:$c))]>,
- Requires<Preds>;
- def iir : BasicNVPTXInst<(outs t.RC:$dst),
- (ins t.Imm:$a, t.Imm:$b, t.RC:$c),
- asmstr,
- [(set t.Ty:$dst, (fma fpimm:$a, fpimm:$b, t.Ty:$c))]>,
- Requires<Preds>;
+multiclass FMA<RegTyInfo t, bit allow_ftz = true, list<Predicate> preds = []> {
+ defvar flag_str = !if(allow_ftz, "$ftz", "");
+ defvar flag_ops = !if(allow_ftz, (ins FTZFlag:$ftz), (ins));
+ defvar op_str = "fma.rn" # flag_str # "." # t.Str;
+
+ let Predicates = preds in {
+ def rrr : BasicFlagsNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, t.RC:$b, t.RC:$c),
+ flag_ops, op_str,
+ [(set t.Ty:$dst, (fma t.Ty:$a, t.Ty:$b, t.Ty:$c))]>;
+
+ if t.SupportsImm then {
+ def rri : BasicFlagsNVPTXInst<(outs t.RC:$dst),
+ (ins t.RC:$a, t.RC:$b, t.Imm:$c),
+ flag_ops, op_str,
+ [(set t.Ty:$dst, (fma t.Ty:$a, t.Ty:$b, fpimm:$c))]>;
+ def rir : BasicFlagsNVPTXInst<(outs t.RC:$dst),
+ (ins t.RC:$a, t.Imm:$b, t.RC:$c),
+ flag_ops, op_str,
+ [(set t.Ty:$dst, (fma t.Ty:$a, fpimm:$b, t.Ty:$c))]>;
+ def rii : BasicFlagsNVPTXInst<(outs t.RC:$dst),
+ (ins t.RC:$a, t.Imm:$b, t.Imm:$c),
+ flag_ops, op_str,
+ [(set t.Ty:$dst, (fma t.Ty:$a, fpimm:$b, fpimm:$c))]>;
+ def iir : BasicFlagsNVPTXInst<(outs t.RC:$dst),
+ (ins t.Imm:$a, t.Imm:$b, t.RC:$c),
+ flag_ops, op_str,
+ [(set t.Ty:$dst, (fma fpimm:$a, fpimm:$b, t.Ty:$c))]>;
+ }
}
}
-defm FMA16_ftz : FMA<"fma.rn.ftz.f16", F16RT, [useFP16Math, doF32FTZ]>;
-defm FMA16 : FMA<"fma.rn.f16", F16RT, [useFP16Math]>;
-defm FMA16x2_ftz : FMA<"fma.rn.ftz.f16x2", F16X2RT, [useFP16Math, doF32FTZ]>;
-defm FMA16x2 : FMA<"fma.rn.f16x2", F16X2RT, [useFP16Math]>;
-defm BFMA16 : FMA<"fma.rn.bf16", BF16RT, [hasBF16Math]>;
-defm BFMA16x2 : FMA<"fma.rn.bf16x2", BF16X2RT, [hasBF16Math]>;
-defm FMA32_ftz : FMA<"fma.rn.ftz.f32", F32RT, [doF32FTZ]>;
-defm FMA32 : FMA<"fma.rn.f32", F32RT>;
-defm FMA64 : FMA<"fma.rn.f64", F64RT>;
+defm FMA_F16 : FMA<F16RT, allow_ftz = true, preds = [useFP16Math]>;
----------------
AlexMaclean wrote:
I think there are enough exceptions that it is best not to think of `allow_ftz` as a property of the float register classes. Individual instructions have enough quirks that it seems cleanest and simplest to define this on a per-instruction basis.
https://github.com/llvm/llvm-project/pull/146410
More information about the llvm-commits mailing list