[llvm] [LoongArch] Make ISD::FSQRT a legal operation with lsx/lasx feature (PR #74795)

via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 7 17:59:34 PST 2023


https://github.com/wangleiat created https://github.com/llvm/llvm-project/pull/74795

Also add instruction selection patterns for:
1. (fdiv 1.0, vector), selected as [x]vfrecip.{s/d}
2. (fdiv 1.0, (fsqrt vector)), selected as [x]vfrsqrt.{s/d}

>From 10436c83af68da01ab3295b3fb0570d61634d384 Mon Sep 17 00:00:00 2001
From: wanglei <wanglei at loongson.cn>
Date: Fri, 8 Dec 2023 09:25:46 +0800
Subject: [PATCH] [LoongArch] Make ISD::FSQRT a legal operation with lsx/lasx
 feature

And add some patterns:
1. (fdiv 1.0, vector)
2. (fdiv 1.0, (fsqrt vector))
---
 .../LoongArch/LoongArchISelLowering.cpp       |  2 +
 .../LoongArch/LoongArchLASXInstrInfo.td       | 22 +++++++
 .../Target/LoongArch/LoongArchLSXInstrInfo.td | 45 +++++++++++++
 llvm/test/CodeGen/LoongArch/lasx/fsqrt.ll     | 65 +++++++++++++++++++
 .../LoongArch/lasx/ir-instruction/fdiv.ll     | 29 +++++++++
 llvm/test/CodeGen/LoongArch/lsx/fsqrt.ll      | 65 +++++++++++++++++++
 .../LoongArch/lsx/ir-instruction/fdiv.ll      | 29 +++++++++
 7 files changed, 257 insertions(+)
 create mode 100644 llvm/test/CodeGen/LoongArch/lasx/fsqrt.ll
 create mode 100644 llvm/test/CodeGen/LoongArch/lsx/fsqrt.ll

diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index e9e29c2dc8929..e4b1d2d30516a 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -269,6 +269,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
       setOperationAction({ISD::FADD, ISD::FSUB}, VT, Legal);
       setOperationAction({ISD::FMUL, ISD::FDIV}, VT, Legal);
       setOperationAction(ISD::FMA, VT, Legal);
+      setOperationAction(ISD::FSQRT, VT, Legal);
       setCondCodeAction({ISD::SETGE, ISD::SETGT, ISD::SETOGE, ISD::SETOGT,
                          ISD::SETUGE, ISD::SETUGT},
                         VT, Expand);
@@ -309,6 +310,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
       setOperationAction({ISD::FADD, ISD::FSUB}, VT, Legal);
       setOperationAction({ISD::FMUL, ISD::FDIV}, VT, Legal);
       setOperationAction(ISD::FMA, VT, Legal);
+      setOperationAction(ISD::FSQRT, VT, Legal);
       setCondCodeAction({ISD::SETGE, ISD::SETGT, ISD::SETOGE, ISD::SETOGT,
                          ISD::SETUGE, ISD::SETUGT},
                         VT, Expand);
diff --git a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
index a9bf65c6840d5..55b90f4450c0d 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
@@ -1092,6 +1092,13 @@ multiclass PatXr<SDPatternOperator OpNode, string Inst> {
             (!cast<LAInst>(Inst#"_D") LASX256:$xj)>;
 }
 
+multiclass PatXrF<SDPatternOperator OpNode, string Inst> {
+  def : Pat<(v8f32 (OpNode (v8f32 LASX256:$xj))),
+            (!cast<LAInst>(Inst#"_S") LASX256:$xj)>;
+  def : Pat<(v4f64 (OpNode (v4f64 LASX256:$xj))),
+            (!cast<LAInst>(Inst#"_D") LASX256:$xj)>;
+}
+
 multiclass PatXrXr<SDPatternOperator OpNode, string Inst> {
   def : Pat<(OpNode (v32i8 LASX256:$xj), (v32i8 LASX256:$xk)),
             (!cast<LAInst>(Inst#"_B") LASX256:$xj, LASX256:$xk)>;
@@ -1448,6 +1455,21 @@ def : Pat<(fma v8f32:$xj, v8f32:$xk, v8f32:$xa),
 def : Pat<(fma v4f64:$xj, v4f64:$xk, v4f64:$xa),
           (XVFMADD_D v4f64:$xj, v4f64:$xk, v4f64:$xa)>;
 
+// XVFSQRT_{S/D}
+defm : PatXrF<fsqrt, "XVFSQRT">;
+
+// XVFRECIP_{S/D}
+def : Pat<(fdiv vsplatf32_fpimm_eq_1, v8f32:$xj),
+          (XVFRECIP_S v8f32:$xj)>;
+def : Pat<(fdiv vsplatf64_fpimm_eq_1, v4f64:$xj),
+          (XVFRECIP_D v4f64:$xj)>;
+
+// XVFRSQRT_{S/D}
+def : Pat<(fdiv vsplatf32_fpimm_eq_1, (fsqrt v8f32:$xj)),
+          (XVFRSQRT_S v8f32:$xj)>;
+def : Pat<(fdiv vsplatf64_fpimm_eq_1, (fsqrt v4f64:$xj)),
+          (XVFRSQRT_D v4f64:$xj)>;
+
 // XVSEQ[I]_{B/H/W/D}
 defm : PatCCXrSimm5<SETEQ, "XVSEQI">;
 defm : PatCCXrXr<SETEQ, "XVSEQ">;
diff --git a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
index ff21c6681271e..8ad0c5904f258 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
@@ -95,6 +95,29 @@ def vsplati64_imm_eq_63 : PatFrags<(ops), [(build_vector),
          Imm.getBitWidth() == EltTy.getSizeInBits() && Imm == 63;
 }]>;
 
+def vsplatf32_fpimm_eq_1
+  : PatFrags<(ops), [(bitconvert (v4i32 (build_vector))),
+                     (bitconvert (v8i32 (build_vector)))], [{
+  APInt Imm;
+  EVT EltTy = N->getValueType(0).getVectorElementType();
+  N = N->getOperand(0).getNode();
+
+  return selectVSplat(N, Imm, EltTy.getSizeInBits()) &&
+         Imm.getBitWidth() == EltTy.getSizeInBits() &&
+         Imm == APFloat(+1.0f).bitcastToAPInt();
+}]>;
+def vsplatf64_fpimm_eq_1
+  : PatFrags<(ops), [(bitconvert (v2i64 (build_vector))),
+                     (bitconvert (v4i64 (build_vector)))], [{
+  APInt Imm;
+  EVT EltTy = N->getValueType(0).getVectorElementType();
+  N = N->getOperand(0).getNode();
+
+  return selectVSplat(N, Imm, EltTy.getSizeInBits()) &&
+         Imm.getBitWidth() == EltTy.getSizeInBits() &&
+         Imm == APFloat(+1.0).bitcastToAPInt();
+}]>;
+
 def vsplati8imm7   : PatFrag<(ops node:$reg),
                              (and node:$reg, vsplati8_imm_eq_7)>;
 def vsplati16imm15 : PatFrag<(ops node:$reg),
@@ -1173,6 +1196,13 @@ multiclass PatVr<SDPatternOperator OpNode, string Inst> {
             (!cast<LAInst>(Inst#"_D") LSX128:$vj)>;
 }
 
+multiclass PatVrF<SDPatternOperator OpNode, string Inst> {
+  def : Pat<(v4f32 (OpNode (v4f32 LSX128:$vj))),
+            (!cast<LAInst>(Inst#"_S") LSX128:$vj)>;
+  def : Pat<(v2f64 (OpNode (v2f64 LSX128:$vj))),
+            (!cast<LAInst>(Inst#"_D") LSX128:$vj)>;
+}
+
 multiclass PatVrVr<SDPatternOperator OpNode, string Inst> {
   def : Pat<(OpNode (v16i8 LSX128:$vj), (v16i8 LSX128:$vk)),
             (!cast<LAInst>(Inst#"_B") LSX128:$vj, LSX128:$vk)>;
@@ -1525,6 +1555,21 @@ def : Pat<(fma v4f32:$vj, v4f32:$vk, v4f32:$va),
 def : Pat<(fma v2f64:$vj, v2f64:$vk, v2f64:$va),
           (VFMADD_D v2f64:$vj, v2f64:$vk, v2f64:$va)>;
 
+// VFSQRT_{S/D}
+defm : PatVrF<fsqrt, "VFSQRT">;
+
+// VFRECIP_{S/D}
+def : Pat<(fdiv vsplatf32_fpimm_eq_1, v4f32:$vj),
+          (VFRECIP_S v4f32:$vj)>;
+def : Pat<(fdiv vsplatf64_fpimm_eq_1, v2f64:$vj),
+          (VFRECIP_D v2f64:$vj)>;
+
+// VFRSQRT_{S/D}
+def : Pat<(fdiv vsplatf32_fpimm_eq_1, (fsqrt v4f32:$vj)),
+          (VFRSQRT_S v4f32:$vj)>;
+def : Pat<(fdiv vsplatf64_fpimm_eq_1, (fsqrt v2f64:$vj)),
+          (VFRSQRT_D v2f64:$vj)>;
+
 // VSEQ[I]_{B/H/W/D}
 defm : PatCCVrSimm5<SETEQ, "VSEQI">;
 defm : PatCCVrVr<SETEQ, "VSEQ">;
diff --git a/llvm/test/CodeGen/LoongArch/lasx/fsqrt.ll b/llvm/test/CodeGen/LoongArch/lasx/fsqrt.ll
new file mode 100644
index 0000000000000..c4a881bdeae9f
--- /dev/null
+++ b/llvm/test/CodeGen/LoongArch/lasx/fsqrt.ll
@@ -0,0 +1,65 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc --mtriple=loongarch64 --mattr=+lasx < %s | FileCheck %s
+
+;; fsqrt
+define void @sqrt_v8f32(ptr %res, ptr %a0) nounwind {
+; CHECK-LABEL: sqrt_v8f32:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    xvld $xr0, $a1, 0
+; CHECK-NEXT:    xvfsqrt.s $xr0, $xr0
+; CHECK-NEXT:    xvst $xr0, $a0, 0
+; CHECK-NEXT:    ret
+entry:
+  %v0 = load <8 x float>, ptr %a0, align 16
+  %sqrt = call <8 x float> @llvm.sqrt.v8f32 (<8 x float> %v0)
+  store <8 x float> %sqrt, ptr %res, align 16
+  ret void
+}
+
+define void @sqrt_v4f64(ptr %res, ptr %a0) nounwind {
+; CHECK-LABEL: sqrt_v4f64:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    xvld $xr0, $a1, 0
+; CHECK-NEXT:    xvfsqrt.d $xr0, $xr0
+; CHECK-NEXT:    xvst $xr0, $a0, 0
+; CHECK-NEXT:    ret
+entry:
+  %v0 = load <4 x double>, ptr %a0, align 16
+  %sqrt = call <4 x double> @llvm.sqrt.v4f64 (<4 x double> %v0)
+  store <4 x double> %sqrt, ptr %res, align 16
+  ret void
+}
+
+;; 1.0 / (fsqrt vec)
+define void @one_div_sqrt_v8f32(ptr %res, ptr %a0) nounwind {
+; CHECK-LABEL: one_div_sqrt_v8f32:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    xvld $xr0, $a1, 0
+; CHECK-NEXT:    xvfrsqrt.s $xr0, $xr0
+; CHECK-NEXT:    xvst $xr0, $a0, 0
+; CHECK-NEXT:    ret
+entry:
+  %v0 = load <8 x float>, ptr %a0, align 16
+  %sqrt = call <8 x float> @llvm.sqrt.v8f32 (<8 x float> %v0)
+  %div = fdiv <8 x float> <float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0>, %sqrt
+  store <8 x float> %div, ptr %res, align 16
+  ret void
+}
+
+define void @one_div_sqrt_v4f64(ptr %res, ptr %a0) nounwind {
+; CHECK-LABEL: one_div_sqrt_v4f64:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    xvld $xr0, $a1, 0
+; CHECK-NEXT:    xvfrsqrt.d $xr0, $xr0
+; CHECK-NEXT:    xvst $xr0, $a0, 0
+; CHECK-NEXT:    ret
+entry:
+  %v0 = load <4 x double>, ptr %a0, align 16
+  %sqrt = call <4 x double> @llvm.sqrt.v4f64 (<4 x double> %v0)
+  %div = fdiv <4 x double> <double 1.0, double 1.0, double 1.0, double 1.0>, %sqrt
+  store <4 x double> %div, ptr %res, align 16
+  ret void
+}
+
+declare <8 x float> @llvm.sqrt.v8f32(<8 x float>)
+declare <4 x double> @llvm.sqrt.v4f64(<4 x double>)
diff --git a/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/fdiv.ll b/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/fdiv.ll
index 284121a79a492..6004565b0b784 100644
--- a/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/fdiv.ll
+++ b/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/fdiv.ll
@@ -32,3 +32,32 @@ entry:
   store <4 x double> %v2, ptr %res
   ret void
 }
+
+;; 1.0 / vec
+define void @one_fdiv_v8f32(ptr %res, ptr %a0) nounwind {
+; CHECK-LABEL: one_fdiv_v8f32:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    xvld $xr0, $a1, 0
+; CHECK-NEXT:    xvfrecip.s $xr0, $xr0
+; CHECK-NEXT:    xvst $xr0, $a0, 0
+; CHECK-NEXT:    ret
+entry:
+  %v0 = load <8 x float>, ptr %a0
+  %div = fdiv <8 x float> <float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0>, %v0
+  store <8 x float> %div, ptr %res
+  ret void
+}
+
+define void @one_fdiv_v4f64(ptr %res, ptr %a0) nounwind {
+; CHECK-LABEL: one_fdiv_v4f64:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    xvld $xr0, $a1, 0
+; CHECK-NEXT:    xvfrecip.d $xr0, $xr0
+; CHECK-NEXT:    xvst $xr0, $a0, 0
+; CHECK-NEXT:    ret
+entry:
+  %v0 = load <4 x double>, ptr %a0
+  %div = fdiv <4 x double> <double 1.0, double 1.0, double 1.0, double 1.0>, %v0
+  store <4 x double> %div, ptr %res
+  ret void
+}
diff --git a/llvm/test/CodeGen/LoongArch/lsx/fsqrt.ll b/llvm/test/CodeGen/LoongArch/lsx/fsqrt.ll
new file mode 100644
index 0000000000000..a57bc1ca0e948
--- /dev/null
+++ b/llvm/test/CodeGen/LoongArch/lsx/fsqrt.ll
@@ -0,0 +1,65 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc --mtriple=loongarch64 --mattr=+lsx < %s | FileCheck %s
+
+;; fsqrt
+define void @sqrt_v4f32(ptr %res, ptr %a0) nounwind {
+; CHECK-LABEL: sqrt_v4f32:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vld $vr0, $a1, 0
+; CHECK-NEXT:    vfsqrt.s $vr0, $vr0
+; CHECK-NEXT:    vst $vr0, $a0, 0
+; CHECK-NEXT:    ret
+entry:
+  %v0 = load <4 x float>, ptr %a0, align 16
+  %sqrt = call <4 x float> @llvm.sqrt.v4f32 (<4 x float> %v0)
+  store <4 x float> %sqrt, ptr %res, align 16
+  ret void
+}
+
+define void @sqrt_v2f64(ptr %res, ptr %a0) nounwind {
+; CHECK-LABEL: sqrt_v2f64:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vld $vr0, $a1, 0
+; CHECK-NEXT:    vfsqrt.d $vr0, $vr0
+; CHECK-NEXT:    vst $vr0, $a0, 0
+; CHECK-NEXT:    ret
+entry:
+  %v0 = load <2 x double>, ptr %a0, align 16
+  %sqrt = call <2 x double> @llvm.sqrt.v2f64 (<2 x double> %v0)
+  store <2 x double> %sqrt, ptr %res, align 16
+  ret void
+}
+
+;; 1.0 / (fsqrt vec)
+define void @one_div_sqrt_v4f32(ptr %res, ptr %a0) nounwind {
+; CHECK-LABEL: one_div_sqrt_v4f32:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vld $vr0, $a1, 0
+; CHECK-NEXT:    vfrsqrt.s $vr0, $vr0
+; CHECK-NEXT:    vst $vr0, $a0, 0
+; CHECK-NEXT:    ret
+entry:
+  %v0 = load <4 x float>, ptr %a0, align 16
+  %sqrt = call <4 x float> @llvm.sqrt.v4f32 (<4 x float> %v0)
+  %div = fdiv <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>, %sqrt
+  store <4 x float> %div, ptr %res, align 16
+  ret void
+}
+
+define void @one_div_sqrt_v2f64(ptr %res, ptr %a0) nounwind {
+; CHECK-LABEL: one_div_sqrt_v2f64:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vld $vr0, $a1, 0
+; CHECK-NEXT:    vfrsqrt.d $vr0, $vr0
+; CHECK-NEXT:    vst $vr0, $a0, 0
+; CHECK-NEXT:    ret
+entry:
+  %v0 = load <2 x double>, ptr %a0, align 16
+  %sqrt = call <2 x double> @llvm.sqrt.v2f64 (<2 x double> %v0)
+  %div = fdiv <2 x double> <double 1.0, double 1.0>, %sqrt
+  store <2 x double> %div, ptr %res, align 16
+  ret void
+}
+
+declare <4 x float> @llvm.sqrt.v4f32(<4 x float>)
+declare <2 x double> @llvm.sqrt.v2f64(<2 x double>)
diff --git a/llvm/test/CodeGen/LoongArch/lsx/ir-instruction/fdiv.ll b/llvm/test/CodeGen/LoongArch/lsx/ir-instruction/fdiv.ll
index eb7c8bd9616ec..5f1ee9e4d212e 100644
--- a/llvm/test/CodeGen/LoongArch/lsx/ir-instruction/fdiv.ll
+++ b/llvm/test/CodeGen/LoongArch/lsx/ir-instruction/fdiv.ll
@@ -32,3 +32,32 @@ entry:
   store <2 x double> %v2, ptr %res
   ret void
 }
+
+;; 1.0 / vec
+define void @one_fdiv_v4f32(ptr %res, ptr %a0) nounwind {
+; CHECK-LABEL: one_fdiv_v4f32:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vld $vr0, $a1, 0
+; CHECK-NEXT:    vfrecip.s $vr0, $vr0
+; CHECK-NEXT:    vst $vr0, $a0, 0
+; CHECK-NEXT:    ret
+entry:
+  %v0 = load <4 x float>, ptr %a0
+  %div = fdiv <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>, %v0
+  store <4 x float> %div, ptr %res
+  ret void
+}
+
+define void @one_fdiv_v2f64(ptr %res, ptr %a0) nounwind {
+; CHECK-LABEL: one_fdiv_v2f64:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vld $vr0, $a1, 0
+; CHECK-NEXT:    vfrecip.d $vr0, $vr0
+; CHECK-NEXT:    vst $vr0, $a0, 0
+; CHECK-NEXT:    ret
+entry:
+  %v0 = load <2 x double>, ptr %a0
+  %div = fdiv <2 x double> <double 1.0, double 1.0>, %v0
+  store <2 x double> %div, ptr %res
+  ret void
+}



More information about the llvm-commits mailing list