[llvm] [RISCV] Match vector fp-int convert intrinsics with specific RTZ rounding mode to the rtz variants (PR #98120)

Jianjian Guan via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 9 00:23:13 PDT 2024


https://github.com/jacquesguan created https://github.com/llvm/llvm-project/pull/98120

None

>From 2f81c0b48ffddf27b06b1327d0f30d8e3e01ebfa Mon Sep 17 00:00:00 2001
From: Jianjian GUAN <jacquesguan at me.com>
Date: Tue, 9 Jul 2024 15:15:11 +0800
Subject: [PATCH] [RISCV] Match vector fp-int convert intrinsics with specific
 RTZ rounding mode to the rtz variants

---
 .../Target/RISCV/RISCVInstrInfoVPseudos.td    | 122 ++++++++++++++++++
 llvm/test/CodeGen/RISCV/rvv/vfcvt-x-f.ll      |  16 +++
 llvm/test/CodeGen/RISCV/rvv/vfcvt-xu-f.ll     |  16 +++
 llvm/test/CodeGen/RISCV/rvv/vfncvt-x-f.ll     |  17 +++
 llvm/test/CodeGen/RISCV/rvv/vfncvt-xu-f.ll    |  16 +++
 llvm/test/CodeGen/RISCV/rvv/vfwcvt-x-f.ll     |  16 +++
 llvm/test/CodeGen/RISCV/rvv/vfwcvt-xu-f.ll    |  16 +++
 7 files changed, 219 insertions(+)

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
index 42d6b03968d74..d72390b7c14b5 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
@@ -3957,6 +3957,28 @@ class VPatUnaryNoMaskRoundingMode<string intrinsic_name,
                    (XLenVT timm:$round),
                    GPR:$vl, log2sew, TU_MU)>;
 
+class VPatUnaryNoMaskRTZ<string intrinsic_name,
+                         string inst,
+                         string kind,
+                         ValueType result_type,
+                         ValueType op2_type,
+                         int log2sew,
+                         LMULInfo vlmul,
+                         VReg result_reg_class,
+                         VReg op2_reg_class,
+                         bit isSEWAware = 0> : // Unmasked intrinsic, static frm == RTZ.
+  Pat<(result_type (!cast<Intrinsic>(intrinsic_name)
+                   (result_type result_reg_class:$merge),
+                   (op2_type op2_reg_class:$rs2),
+                   (XLenVT 0b001), // Only matches frm operand 0b001 (RTZ).
+                   VLOpFrag)),
+                   (!cast<Instruction>(
+                      !if(isSEWAware, // SEW-aware pseudos carry an _E<SEW> suffix.
+                          inst#"_"#kind#"_"#vlmul.MX#"_E"#!shl(1, log2sew),
+                          inst#"_"#kind#"_"#vlmul.MX))
+                   (result_type result_reg_class:$merge),
+                   (op2_type op2_reg_class:$rs2), // No frm operand; assumes inst is an RTZ pseudo.
+                   GPR:$vl, log2sew, TU_MU)>; // Fixed TU_MU, as in the _RoundingMode form.
 
 class VPatUnaryMask<string intrinsic_name,
                     string inst,
@@ -4009,6 +4031,31 @@ class VPatUnaryMaskRoundingMode<string intrinsic_name,
                    (XLenVT timm:$round),
                    GPR:$vl, log2sew, (XLenVT timm:$policy))>;
 
+class VPatUnaryMaskRTZ<string intrinsic_name,
+                       string inst,
+                       string kind,
+                       ValueType result_type,
+                       ValueType op2_type,
+                       ValueType mask_type,
+                       int log2sew,
+                       LMULInfo vlmul,
+                       VReg result_reg_class,
+                       VReg op2_reg_class,
+                       bit isSEWAware = 0> : // Masked (_mask) intrinsic, static frm == RTZ.
+  Pat<(result_type (!cast<Intrinsic>(intrinsic_name#"_mask")
+                   (result_type result_reg_class:$merge),
+                   (op2_type op2_reg_class:$rs2),
+                   (mask_type V0),
+                   (XLenVT 0b001), // Only matches frm operand 0b001 (RTZ).
+                   VLOpFrag, (XLenVT timm:$policy))),
+                   (!cast<Instruction>(
+                      !if(isSEWAware, // SEW-aware pseudos carry an _E<SEW> suffix.
+                          inst#"_"#kind#"_"#vlmul.MX#"_E"#!shl(1, log2sew)#"_MASK",
+                          inst#"_"#kind#"_"#vlmul.MX#"_MASK"))
+                   (result_type result_reg_class:$merge),
+                   (op2_type op2_reg_class:$rs2), // No frm operand; assumes inst is an RTZ pseudo.
+                   (mask_type V0),
+                   GPR:$vl, log2sew, (XLenVT timm:$policy))>; // Policy passed through unchanged.
 
 class VPatMaskUnaryNoMask<string intrinsic_name,
                           string inst,
@@ -4826,6 +4873,25 @@ multiclass VPatConversionRoundingMode<string intrinsic,
                                   op1_reg_class, isSEWAware>;
 }
 
+multiclass VPatConversionRTZ<string intrinsic,
+                             string inst,
+                             string kind,
+                             ValueType result_type,
+                             ValueType op1_type,
+                             ValueType mask_type,
+                             int log2sew,
+                             LMULInfo vlmul,
+                             VReg result_reg_class,
+                             VReg op1_reg_class,
+                             bit isSEWAware = 0> { // Unmasked + masked RTZ patterns.
+  def : VPatUnaryNoMaskRTZ<intrinsic, inst, kind, result_type, op1_type,
+                           log2sew, vlmul, result_reg_class,
+                           op1_reg_class, isSEWAware>;
+  def : VPatUnaryMaskRTZ<intrinsic, inst, kind, result_type, op1_type,
+                         mask_type, log2sew, vlmul, result_reg_class,
+                         op1_reg_class, isSEWAware>;
+}
+
 multiclass VPatBinaryV_VV<string intrinsic, string instruction,
                           list<VTypeInfo> vtilist, bit isSEWAware = 0> {
   foreach vti = vtilist in
@@ -5776,6 +5842,18 @@ multiclass VPatConversionVI_VF_RM<string intrinsic,
   }
 }
 
+// FP->int conversion to GetIntVTypeInfo<fvti>, matched only for frm == RTZ.
+multiclass VPatConversionVI_VF_RTZ<string intrinsic, string instruction> {
+  foreach fvti = AllFloatVectors in {
+    defvar ivti = GetIntVTypeInfo<fvti>.Vti;
+    let Predicates = !listconcat(GetVTypePredicates<fvti>.Predicates,
+                                 GetVTypePredicates<ivti>.Predicates) in
+    defm : VPatConversionRTZ<intrinsic, instruction, "V",
+                             ivti.Vector, fvti.Vector, ivti.Mask, fvti.Log2SEW,
+                             fvti.LMul, ivti.RegClass, fvti.RegClass>;
+  }
+}
+
 multiclass VPatConversionVF_VI_RM<string intrinsic, string instruction,
                                   bit isSEWAware = 0> {
   foreach fvti = AllFloatVectors in {
@@ -5813,6 +5891,18 @@ multiclass VPatConversionWI_VF_RM<string intrinsic, string instruction> {
   }
 }
 
+multiclass VPatConversionWI_VF_RTZ<string intrinsic, string instruction> {
+  foreach fvtiToFWti = AllWidenableFloatVectors in { // Widening FP->int, frm == RTZ.
+    defvar fvti = fvtiToFWti.Vti;
+    defvar iwti = GetIntVTypeInfo<fvtiToFWti.Wti>.Vti; // Int type of the wide FP type.
+    let Predicates = !listconcat(GetVTypePredicates<fvti>.Predicates,
+                                 GetVTypePredicates<iwti>.Predicates) in
+    defm : VPatConversionRTZ<intrinsic, instruction, "V",
+                             iwti.Vector, fvti.Vector, iwti.Mask, fvti.Log2SEW,
+                             fvti.LMul, iwti.RegClass, fvti.RegClass>; // SEW/LMUL of the FP source.
+  }
+}
+
 multiclass VPatConversionWF_VI<string intrinsic, string instruction,
                                bit isSEWAware = 0> {
   foreach vtiToWti = AllWidenableIntToFloatVectors in {
@@ -5879,6 +5969,18 @@ multiclass VPatConversionVI_WF_RM <string intrinsic, string instruction> {
   }
 }
 
+multiclass VPatConversionVI_WF_RTZ <string intrinsic, string instruction> {
+  foreach vtiToWti = AllWidenableIntToFloatVectors in { // Narrowing FP->int, frm == RTZ.
+    defvar vti = vtiToWti.Vti;
+    defvar fwti = vtiToWti.Wti; // Wide FP source type.
+    let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
+                                 GetVTypePredicates<fwti>.Predicates) in
+    defm : VPatConversionRTZ<intrinsic, instruction, "W",
+                             vti.Vector, fwti.Vector, vti.Mask, vti.Log2SEW,
+                             vti.LMul, vti.RegClass, fwti.RegClass>; // SEW/LMUL of the int result.
+  }
+}
+
 multiclass VPatConversionVF_WI_RM <string intrinsic, string instruction,
                                    bit isSEWAware = 0> {
   foreach fvtiToFWti = AllWidenableFloatVectors in {
@@ -7153,6 +7269,8 @@ foreach fvti = AllFloatVectors in {
 //===----------------------------------------------------------------------===//
 // 13.17. Single-Width Floating-Point/Integer Type-Convert Instructions
 //===----------------------------------------------------------------------===//
+defm : VPatConversionVI_VF_RTZ<"int_riscv_vfcvt_x_f_v", "PseudoVFCVT_RTZ_X_F">; // frm == RTZ
+defm : VPatConversionVI_VF_RTZ<"int_riscv_vfcvt_xu_f_v", "PseudoVFCVT_RTZ_XU_F">; // frm == RTZ
 defm : VPatConversionVI_VF_RM<"int_riscv_vfcvt_x_f_v", "PseudoVFCVT_X_F">;
 defm : VPatConversionVI_VF_RM<"int_riscv_vfcvt_xu_f_v", "PseudoVFCVT_XU_F">;
 defm : VPatConversionVI_VF<"int_riscv_vfcvt_rtz_xu_f_v", "PseudoVFCVT_RTZ_XU_F">;
@@ -7165,6 +7283,8 @@ defm : VPatConversionVF_VI_RM<"int_riscv_vfcvt_f_xu_v", "PseudoVFCVT_F_XU",
 //===----------------------------------------------------------------------===//
 // 13.18. Widening Floating-Point/Integer Type-Convert Instructions
 //===----------------------------------------------------------------------===//
+defm : VPatConversionWI_VF_RTZ<"int_riscv_vfwcvt_xu_f_v", "PseudoVFWCVT_RTZ_XU_F">; // frm == RTZ
+defm : VPatConversionWI_VF_RTZ<"int_riscv_vfwcvt_x_f_v", "PseudoVFWCVT_RTZ_X_F">; // frm == RTZ
 defm : VPatConversionWI_VF_RM<"int_riscv_vfwcvt_xu_f_v", "PseudoVFWCVT_XU_F">;
 defm : VPatConversionWI_VF_RM<"int_riscv_vfwcvt_x_f_v", "PseudoVFWCVT_X_F">;
 defm : VPatConversionWI_VF<"int_riscv_vfwcvt_rtz_xu_f_v", "PseudoVFWCVT_RTZ_XU_F">;
@@ -7181,6 +7301,8 @@ defm : VPatConversionWF_VF_BF<"int_riscv_vfwcvtbf16_f_f_v",
 //===----------------------------------------------------------------------===//
 // 13.19. Narrowing Floating-Point/Integer Type-Convert Instructions
 //===----------------------------------------------------------------------===//
+defm : VPatConversionVI_WF_RTZ<"int_riscv_vfncvt_xu_f_w", "PseudoVFNCVT_RTZ_XU_F">; // frm == RTZ
+defm : VPatConversionVI_WF_RTZ<"int_riscv_vfncvt_x_f_w", "PseudoVFNCVT_RTZ_X_F">; // frm == RTZ
 defm : VPatConversionVI_WF_RM<"int_riscv_vfncvt_xu_f_w", "PseudoVFNCVT_XU_F">;
 defm : VPatConversionVI_WF_RM<"int_riscv_vfncvt_x_f_w", "PseudoVFNCVT_X_F">;
 defm : VPatConversionVI_WF<"int_riscv_vfncvt_rtz_xu_f_w", "PseudoVFNCVT_RTZ_XU_F">;
diff --git a/llvm/test/CodeGen/RISCV/rvv/vfcvt-x-f.ll b/llvm/test/CodeGen/RISCV/rvv/vfcvt-x-f.ll
index 68a85530ea242..582c302dd2a15 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vfcvt-x-f.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vfcvt-x-f.ll
@@ -693,3 +693,19 @@ entry:
 
   ret <vscale x 8 x i64> %a
 }
+
+define <vscale x 8 x i64> @intrinsic_vfcvt_mask_x.f.v_rtz_nxv8i64_nxv8f64(<vscale x 8 x i64> %0, <vscale x 8 x double> %1, <vscale x 8 x i1> %2, iXLen %3) nounwind {
+; CHECK-LABEL: intrinsic_vfcvt_mask_x.f.v_rtz_nxv8i64_nxv8f64:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli zero, a0, e64, m8, ta, mu
+; CHECK-NEXT:    vfcvt.rtz.x.f.v v8, v16, v0.t
+; CHECK-NEXT:    ret
+entry:
+  %a = call <vscale x 8 x i64> @llvm.riscv.vfcvt.x.f.v.mask.nxv8i64.nxv8f64(
+    <vscale x 8 x i64> %0,
+    <vscale x 8 x double> %1,
+    <vscale x 8 x i1> %2,
+    iXLen 1, iXLen %3, iXLen 1) ; frm = 1 (RTZ), vl, policy = 1 (ta, mu)
+
+  ret <vscale x 8 x i64> %a
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/vfcvt-xu-f.ll b/llvm/test/CodeGen/RISCV/rvv/vfcvt-xu-f.ll
index 93716ba7f451c..708b38b8ed116 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vfcvt-xu-f.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vfcvt-xu-f.ll
@@ -693,3 +693,19 @@ entry:
 
   ret <vscale x 8 x i64> %a
 }
+
+define <vscale x 8 x i64> @intrinsic_vfcvt_mask_xu.f.v_rtz_nxv8i64_nxv8f64(<vscale x 8 x i64> %0, <vscale x 8 x double> %1, <vscale x 8 x i1> %2, iXLen %3) nounwind {
+; CHECK-LABEL: intrinsic_vfcvt_mask_xu.f.v_rtz_nxv8i64_nxv8f64:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli zero, a0, e64, m8, ta, mu
+; CHECK-NEXT:    vfcvt.rtz.xu.f.v v8, v16, v0.t
+; CHECK-NEXT:    ret
+entry:
+  %a = call <vscale x 8 x i64> @llvm.riscv.vfcvt.xu.f.v.mask.nxv8i64.nxv8f64(
+    <vscale x 8 x i64> %0,
+    <vscale x 8 x double> %1,
+    <vscale x 8 x i1> %2,
+    iXLen 1, iXLen %3, iXLen 1) ; frm = 1 (RTZ), vl, policy = 1 (ta, mu)
+
+  ret <vscale x 8 x i64> %a
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/vfncvt-x-f.ll b/llvm/test/CodeGen/RISCV/rvv/vfncvt-x-f.ll
index e4b39c655a102..334d5eba03001 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vfncvt-x-f.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vfncvt-x-f.ll
@@ -708,3 +708,19 @@ entry:
 
   ret <vscale x 8 x i32> %a
 }
+
+define <vscale x 8 x i32> @intrinsic_vfncvt_mask_x.f.w_rtz_nxv8i32_nxv8f64(<vscale x 8 x i32> %0, <vscale x 8 x double> %1, <vscale x 8 x i1> %2, iXLen %3) nounwind {
+; CHECK-LABEL: intrinsic_vfncvt_mask_x.f.w_rtz_nxv8i32_nxv8f64:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli zero, a0, e32, m4, ta, mu
+; CHECK-NEXT:    vfncvt.rtz.x.f.w v8, v16, v0.t
+; CHECK-NEXT:    ret
+entry:
+  %a = call <vscale x 8 x i32> @llvm.riscv.vfncvt.x.f.w.mask.nxv8i32.nxv8f64(
+    <vscale x 8 x i32> %0,
+    <vscale x 8 x double> %1,
+    <vscale x 8 x i1> %2,
+    iXLen 1, iXLen %3, iXLen 1) ; frm = 1 (RTZ), vl, policy = 1 (ta, mu)
+
+  ret <vscale x 8 x i32> %a
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/vfncvt-xu-f.ll b/llvm/test/CodeGen/RISCV/rvv/vfncvt-xu-f.ll
index fd922438d05b3..bea99a0e81a34 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vfncvt-xu-f.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vfncvt-xu-f.ll
@@ -708,3 +708,19 @@ entry:
 
   ret <vscale x 8 x i32> %a
 }
+
+define <vscale x 8 x i32> @intrinsic_vfncvt_mask_xu.f.w_rtz_nxv8i32_nxv8f64(<vscale x 8 x i32> %0, <vscale x 8 x double> %1, <vscale x 8 x i1> %2, iXLen %3) nounwind {
+; CHECK-LABEL: intrinsic_vfncvt_mask_xu.f.w_rtz_nxv8i32_nxv8f64:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli zero, a0, e32, m4, ta, mu
+; CHECK-NEXT:    vfncvt.rtz.xu.f.w v8, v16, v0.t
+; CHECK-NEXT:    ret
+entry:
+  %a = call <vscale x 8 x i32> @llvm.riscv.vfncvt.xu.f.w.mask.nxv8i32.nxv8f64(
+    <vscale x 8 x i32> %0,
+    <vscale x 8 x double> %1,
+    <vscale x 8 x i1> %2,
+    iXLen 1, iXLen %3, iXLen 1) ; frm = 1 (RTZ), vl, policy = 1 (ta, mu)
+
+  ret <vscale x 8 x i32> %a
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/vfwcvt-x-f.ll b/llvm/test/CodeGen/RISCV/rvv/vfwcvt-x-f.ll
index 23b10250dfa48..9a80e02bbbbb4 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vfwcvt-x-f.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vfwcvt-x-f.ll
@@ -426,3 +426,19 @@ entry:
 
   ret <vscale x 8 x i64> %a
 }
+
+define <vscale x 8 x i64> @intrinsic_vfwcvt_mask_x.f.v_rtz_nxv8i64_nxv8f32(<vscale x 8 x i64> %0, <vscale x 8 x float> %1, <vscale x 8 x i1> %2, iXLen %3) nounwind {
+; CHECK-LABEL: intrinsic_vfwcvt_mask_x.f.v_rtz_nxv8i64_nxv8f32:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli zero, a0, e32, m4, ta, mu
+; CHECK-NEXT:    vfwcvt.rtz.x.f.v v8, v16, v0.t
+; CHECK-NEXT:    ret
+entry:
+  %a = call <vscale x 8 x i64> @llvm.riscv.vfwcvt.x.f.v.mask.nxv8i64.nxv8f32(
+    <vscale x 8 x i64> %0,
+    <vscale x 8 x float> %1,
+    <vscale x 8 x i1> %2,
+    iXLen 1, iXLen %3, iXLen 1) ; frm = 1 (RTZ), vl, policy = 1 (ta, mu)
+
+  ret <vscale x 8 x i64> %a
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/vfwcvt-xu-f.ll b/llvm/test/CodeGen/RISCV/rvv/vfwcvt-xu-f.ll
index f6779ec9ba5aa..98caaf91ab3c0 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vfwcvt-xu-f.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vfwcvt-xu-f.ll
@@ -426,3 +426,19 @@ entry:
 
   ret <vscale x 8 x i64> %a
 }
+
+define <vscale x 8 x i64> @intrinsic_vfwcvt_mask_xu.f.v_rtz_nxv8i64_nxv8f32(<vscale x 8 x i64> %0, <vscale x 8 x float> %1, <vscale x 8 x i1> %2, iXLen %3) nounwind {
+; CHECK-LABEL: intrinsic_vfwcvt_mask_xu.f.v_rtz_nxv8i64_nxv8f32:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli zero, a0, e32, m4, ta, mu
+; CHECK-NEXT:    vfwcvt.rtz.xu.f.v v8, v16, v0.t
+; CHECK-NEXT:    ret
+entry:
+  %a = call <vscale x 8 x i64> @llvm.riscv.vfwcvt.xu.f.v.mask.nxv8i64.nxv8f32(
+    <vscale x 8 x i64> %0,
+    <vscale x 8 x float> %1,
+    <vscale x 8 x i1> %2,
+    iXLen 1, iXLen %3, iXLen 1) ; frm = 1 (RTZ), vl, policy = 1 (ta, mu)
+
+  ret <vscale x 8 x i64> %a
+}



More information about the llvm-commits mailing list