[llvm] [LoongArch] Make ISD::FSQRT a legal operation with lsx/lasx feature (PR #74795)

via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 7 18:00:05 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-loongarch

Author: wanglei (wangleiat)

<details>
<summary>Changes</summary>

This change also adds instruction-selection patterns for:
1. (fdiv 1.0, vector)
2. (fdiv 1.0, (fsqrt vector))

---
Full diff: https://github.com/llvm/llvm-project/pull/74795.diff


7 Files Affected:

- (modified) llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp (+2) 
- (modified) llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td (+22) 
- (modified) llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td (+45) 
- (added) llvm/test/CodeGen/LoongArch/lasx/fsqrt.ll (+65) 
- (modified) llvm/test/CodeGen/LoongArch/lasx/ir-instruction/fdiv.ll (+29) 
- (added) llvm/test/CodeGen/LoongArch/lsx/fsqrt.ll (+65) 
- (modified) llvm/test/CodeGen/LoongArch/lsx/ir-instruction/fdiv.ll (+29) 


``````````diff
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index e9e29c2dc8929d..e4b1d2d30516a7 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -269,6 +269,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
       setOperationAction({ISD::FADD, ISD::FSUB}, VT, Legal);
       setOperationAction({ISD::FMUL, ISD::FDIV}, VT, Legal);
       setOperationAction(ISD::FMA, VT, Legal);
+      setOperationAction(ISD::FSQRT, VT, Legal);
       setCondCodeAction({ISD::SETGE, ISD::SETGT, ISD::SETOGE, ISD::SETOGT,
                          ISD::SETUGE, ISD::SETUGT},
                         VT, Expand);
@@ -309,6 +310,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
       setOperationAction({ISD::FADD, ISD::FSUB}, VT, Legal);
       setOperationAction({ISD::FMUL, ISD::FDIV}, VT, Legal);
       setOperationAction(ISD::FMA, VT, Legal);
+      setOperationAction(ISD::FSQRT, VT, Legal);
       setCondCodeAction({ISD::SETGE, ISD::SETGT, ISD::SETOGE, ISD::SETOGT,
                          ISD::SETUGE, ISD::SETUGT},
                         VT, Expand);
diff --git a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
index a9bf65c6840d5c..55b90f4450c0d3 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
@@ -1092,6 +1092,13 @@ multiclass PatXr<SDPatternOperator OpNode, string Inst> {
             (!cast<LAInst>(Inst#"_D") LASX256:$xj)>;
 }
 
+multiclass PatXrF<SDPatternOperator OpNode, string Inst> {
+  def : Pat<(v8f32 (OpNode (v8f32 LASX256:$xj))),
+            (!cast<LAInst>(Inst#"_S") LASX256:$xj)>;
+  def : Pat<(v4f64 (OpNode (v4f64 LASX256:$xj))),
+            (!cast<LAInst>(Inst#"_D") LASX256:$xj)>;
+}
+
 multiclass PatXrXr<SDPatternOperator OpNode, string Inst> {
   def : Pat<(OpNode (v32i8 LASX256:$xj), (v32i8 LASX256:$xk)),
             (!cast<LAInst>(Inst#"_B") LASX256:$xj, LASX256:$xk)>;
@@ -1448,6 +1455,21 @@ def : Pat<(fma v8f32:$xj, v8f32:$xk, v8f32:$xa),
 def : Pat<(fma v4f64:$xj, v4f64:$xk, v4f64:$xa),
           (XVFMADD_D v4f64:$xj, v4f64:$xk, v4f64:$xa)>;
 
+// XVFSQRT_{S/D}
+defm : PatXrF<fsqrt, "XVFSQRT">;
+
+// XVFRECIP_{S/D}
+def : Pat<(fdiv vsplatf32_fpimm_eq_1, v8f32:$xj),
+          (XVFRECIP_S v8f32:$xj)>;
+def : Pat<(fdiv vsplatf64_fpimm_eq_1, v4f64:$xj),
+          (XVFRECIP_D v4f64:$xj)>;
+
+// XVFRSQRT_{S/D}
+def : Pat<(fdiv vsplatf32_fpimm_eq_1, (fsqrt v8f32:$xj)),
+          (XVFRSQRT_S v8f32:$xj)>;
+def : Pat<(fdiv vsplatf64_fpimm_eq_1, (fsqrt v4f64:$xj)),
+          (XVFRSQRT_D v4f64:$xj)>;
+
 // XVSEQ[I]_{B/H/W/D}
 defm : PatCCXrSimm5<SETEQ, "XVSEQI">;
 defm : PatCCXrXr<SETEQ, "XVSEQ">;
diff --git a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
index ff21c6681271e7..8ad0c5904f2583 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
@@ -95,6 +95,29 @@ def vsplati64_imm_eq_63 : PatFrags<(ops), [(build_vector),
          Imm.getBitWidth() == EltTy.getSizeInBits() && Imm == 63;
 }]>;
 
+def vsplatf32_fpimm_eq_1
+  : PatFrags<(ops), [(bitconvert (v4i32 (build_vector))),
+                     (bitconvert (v8i32 (build_vector)))], [{
+  APInt Imm;
+  EVT EltTy = N->getValueType(0).getVectorElementType();
+  N = N->getOperand(0).getNode();
+
+  return selectVSplat(N, Imm, EltTy.getSizeInBits()) &&
+         Imm.getBitWidth() == EltTy.getSizeInBits() &&
+         Imm == APFloat(+1.0f).bitcastToAPInt();
+}]>;
+def vsplatf64_fpimm_eq_1
+  : PatFrags<(ops), [(bitconvert (v2i64 (build_vector))),
+                     (bitconvert (v4i64 (build_vector)))], [{
+  APInt Imm;
+  EVT EltTy = N->getValueType(0).getVectorElementType();
+  N = N->getOperand(0).getNode();
+
+  return selectVSplat(N, Imm, EltTy.getSizeInBits()) &&
+         Imm.getBitWidth() == EltTy.getSizeInBits() &&
+         Imm == APFloat(+1.0).bitcastToAPInt();
+}]>;
+
 def vsplati8imm7   : PatFrag<(ops node:$reg),
                              (and node:$reg, vsplati8_imm_eq_7)>;
 def vsplati16imm15 : PatFrag<(ops node:$reg),
@@ -1173,6 +1196,13 @@ multiclass PatVr<SDPatternOperator OpNode, string Inst> {
             (!cast<LAInst>(Inst#"_D") LSX128:$vj)>;
 }
 
+multiclass PatVrF<SDPatternOperator OpNode, string Inst> {
+  def : Pat<(v4f32 (OpNode (v4f32 LSX128:$vj))),
+            (!cast<LAInst>(Inst#"_S") LSX128:$vj)>;
+  def : Pat<(v2f64 (OpNode (v2f64 LSX128:$vj))),
+            (!cast<LAInst>(Inst#"_D") LSX128:$vj)>;
+}
+
 multiclass PatVrVr<SDPatternOperator OpNode, string Inst> {
   def : Pat<(OpNode (v16i8 LSX128:$vj), (v16i8 LSX128:$vk)),
             (!cast<LAInst>(Inst#"_B") LSX128:$vj, LSX128:$vk)>;
@@ -1525,6 +1555,21 @@ def : Pat<(fma v4f32:$vj, v4f32:$vk, v4f32:$va),
 def : Pat<(fma v2f64:$vj, v2f64:$vk, v2f64:$va),
           (VFMADD_D v2f64:$vj, v2f64:$vk, v2f64:$va)>;
 
+// VFSQRT_{S/D}
+defm : PatVrF<fsqrt, "VFSQRT">;
+
+// VFRECIP_{S/D}
+def : Pat<(fdiv vsplatf32_fpimm_eq_1, v4f32:$vj),
+          (VFRECIP_S v4f32:$vj)>;
+def : Pat<(fdiv vsplatf64_fpimm_eq_1, v2f64:$vj),
+          (VFRECIP_D v2f64:$vj)>;
+
+// VFRSQRT_{S/D}
+def : Pat<(fdiv vsplatf32_fpimm_eq_1, (fsqrt v4f32:$vj)),
+          (VFRSQRT_S v4f32:$vj)>;
+def : Pat<(fdiv vsplatf64_fpimm_eq_1, (fsqrt v2f64:$vj)),
+          (VFRSQRT_D v2f64:$vj)>;
+
 // VSEQ[I]_{B/H/W/D}
 defm : PatCCVrSimm5<SETEQ, "VSEQI">;
 defm : PatCCVrVr<SETEQ, "VSEQ">;
diff --git a/llvm/test/CodeGen/LoongArch/lasx/fsqrt.ll b/llvm/test/CodeGen/LoongArch/lasx/fsqrt.ll
new file mode 100644
index 00000000000000..c4a881bdeae9f1
--- /dev/null
+++ b/llvm/test/CodeGen/LoongArch/lasx/fsqrt.ll
@@ -0,0 +1,65 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc --mtriple=loongarch64 --mattr=+lasx < %s | FileCheck %s
+
+;; fsqrt
+define void @sqrt_v8f32(ptr %res, ptr %a0) nounwind {
+; CHECK-LABEL: sqrt_v8f32:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    xvld $xr0, $a1, 0
+; CHECK-NEXT:    xvfsqrt.s $xr0, $xr0
+; CHECK-NEXT:    xvst $xr0, $a0, 0
+; CHECK-NEXT:    ret
+entry:
+  %v0 = load <8 x float>, ptr %a0, align 16
+  %sqrt = call <8 x float> @llvm.sqrt.v8f32 (<8 x float> %v0)
+  store <8 x float> %sqrt, ptr %res, align 16
+  ret void
+}
+
+define void @sqrt_v4f64(ptr %res, ptr %a0) nounwind {
+; CHECK-LABEL: sqrt_v4f64:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    xvld $xr0, $a1, 0
+; CHECK-NEXT:    xvfsqrt.d $xr0, $xr0
+; CHECK-NEXT:    xvst $xr0, $a0, 0
+; CHECK-NEXT:    ret
+entry:
+  %v0 = load <4 x double>, ptr %a0, align 16
+  %sqrt = call <4 x double> @llvm.sqrt.v4f64 (<4 x double> %v0)
+  store <4 x double> %sqrt, ptr %res, align 16
+  ret void
+}
+
+;; 1.0 / (fsqrt vec)
+define void @one_div_sqrt_v8f32(ptr %res, ptr %a0) nounwind {
+; CHECK-LABEL: one_div_sqrt_v8f32:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    xvld $xr0, $a1, 0
+; CHECK-NEXT:    xvfrsqrt.s $xr0, $xr0
+; CHECK-NEXT:    xvst $xr0, $a0, 0
+; CHECK-NEXT:    ret
+entry:
+  %v0 = load <8 x float>, ptr %a0, align 16
+  %sqrt = call <8 x float> @llvm.sqrt.v8f32 (<8 x float> %v0)
+  %div = fdiv <8 x float> <float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0>, %sqrt
+  store <8 x float> %div, ptr %res, align 16
+  ret void
+}
+
+define void @one_div_sqrt_v4f64(ptr %res, ptr %a0) nounwind {
+; CHECK-LABEL: one_div_sqrt_v4f64:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    xvld $xr0, $a1, 0
+; CHECK-NEXT:    xvfrsqrt.d $xr0, $xr0
+; CHECK-NEXT:    xvst $xr0, $a0, 0
+; CHECK-NEXT:    ret
+entry:
+  %v0 = load <4 x double>, ptr %a0, align 16
+  %sqrt = call <4 x double> @llvm.sqrt.v4f64 (<4 x double> %v0)
+  %div = fdiv <4 x double> <double 1.0, double 1.0, double 1.0, double 1.0>, %sqrt
+  store <4 x double> %div, ptr %res, align 16
+  ret void
+}
+
+declare <8 x float> @llvm.sqrt.v8f32(<8 x float>)
+declare <4 x double> @llvm.sqrt.v4f64(<4 x double>)
diff --git a/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/fdiv.ll b/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/fdiv.ll
index 284121a79a492d..6004565b0b784e 100644
--- a/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/fdiv.ll
+++ b/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/fdiv.ll
@@ -32,3 +32,32 @@ entry:
   store <4 x double> %v2, ptr %res
   ret void
 }
+
+;; 1.0 / vec
+define void @one_fdiv_v8f32(ptr %res, ptr %a0) nounwind {
+; CHECK-LABEL: one_fdiv_v8f32:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    xvld $xr0, $a1, 0
+; CHECK-NEXT:    xvfrecip.s $xr0, $xr0
+; CHECK-NEXT:    xvst $xr0, $a0, 0
+; CHECK-NEXT:    ret
+entry:
+  %v0 = load <8 x float>, ptr %a0
+  %div = fdiv <8 x float> <float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0>, %v0
+  store <8 x float> %div, ptr %res
+  ret void
+}
+
+define void @one_fdiv_v4f64(ptr %res, ptr %a0) nounwind {
+; CHECK-LABEL: one_fdiv_v4f64:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    xvld $xr0, $a1, 0
+; CHECK-NEXT:    xvfrecip.d $xr0, $xr0
+; CHECK-NEXT:    xvst $xr0, $a0, 0
+; CHECK-NEXT:    ret
+entry:
+  %v0 = load <4 x double>, ptr %a0
+  %div = fdiv <4 x double> <double 1.0, double 1.0, double 1.0, double 1.0>, %v0
+  store <4 x double> %div, ptr %res
+  ret void
+}
diff --git a/llvm/test/CodeGen/LoongArch/lsx/fsqrt.ll b/llvm/test/CodeGen/LoongArch/lsx/fsqrt.ll
new file mode 100644
index 00000000000000..a57bc1ca0e9488
--- /dev/null
+++ b/llvm/test/CodeGen/LoongArch/lsx/fsqrt.ll
@@ -0,0 +1,65 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc --mtriple=loongarch64 --mattr=+lsx < %s | FileCheck %s
+
+;; fsqrt
+define void @sqrt_v4f32(ptr %res, ptr %a0) nounwind {
+; CHECK-LABEL: sqrt_v4f32:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vld $vr0, $a1, 0
+; CHECK-NEXT:    vfsqrt.s $vr0, $vr0
+; CHECK-NEXT:    vst $vr0, $a0, 0
+; CHECK-NEXT:    ret
+entry:
+  %v0 = load <4 x float>, ptr %a0, align 16
+  %sqrt = call <4 x float> @llvm.sqrt.v4f32 (<4 x float> %v0)
+  store <4 x float> %sqrt, ptr %res, align 16
+  ret void
+}
+
+define void @sqrt_v2f64(ptr %res, ptr %a0) nounwind {
+; CHECK-LABEL: sqrt_v2f64:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vld $vr0, $a1, 0
+; CHECK-NEXT:    vfsqrt.d $vr0, $vr0
+; CHECK-NEXT:    vst $vr0, $a0, 0
+; CHECK-NEXT:    ret
+entry:
+  %v0 = load <2 x double>, ptr %a0, align 16
+  %sqrt = call <2 x double> @llvm.sqrt.v2f64 (<2 x double> %v0)
+  store <2 x double> %sqrt, ptr %res, align 16
+  ret void
+}
+
+;; 1.0 / (fsqrt vec)
+define void @one_div_sqrt_v4f32(ptr %res, ptr %a0) nounwind {
+; CHECK-LABEL: one_div_sqrt_v4f32:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vld $vr0, $a1, 0
+; CHECK-NEXT:    vfrsqrt.s $vr0, $vr0
+; CHECK-NEXT:    vst $vr0, $a0, 0
+; CHECK-NEXT:    ret
+entry:
+  %v0 = load <4 x float>, ptr %a0, align 16
+  %sqrt = call <4 x float> @llvm.sqrt.v4f32 (<4 x float> %v0)
+  %div = fdiv <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>, %sqrt
+  store <4 x float> %div, ptr %res, align 16
+  ret void
+}
+
+define void @one_div_sqrt_v2f64(ptr %res, ptr %a0) nounwind {
+; CHECK-LABEL: one_div_sqrt_v2f64:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vld $vr0, $a1, 0
+; CHECK-NEXT:    vfrsqrt.d $vr0, $vr0
+; CHECK-NEXT:    vst $vr0, $a0, 0
+; CHECK-NEXT:    ret
+entry:
+  %v0 = load <2 x double>, ptr %a0, align 16
+  %sqrt = call <2 x double> @llvm.sqrt.v2f64 (<2 x double> %v0)
+  %div = fdiv <2 x double> <double 1.0, double 1.0>, %sqrt
+  store <2 x double> %div, ptr %res, align 16
+  ret void
+}
+
+declare <4 x float> @llvm.sqrt.v4f32(<4 x float>)
+declare <2 x double> @llvm.sqrt.v2f64(<2 x double>)
diff --git a/llvm/test/CodeGen/LoongArch/lsx/ir-instruction/fdiv.ll b/llvm/test/CodeGen/LoongArch/lsx/ir-instruction/fdiv.ll
index eb7c8bd9616ec7..5f1ee9e4d212eb 100644
--- a/llvm/test/CodeGen/LoongArch/lsx/ir-instruction/fdiv.ll
+++ b/llvm/test/CodeGen/LoongArch/lsx/ir-instruction/fdiv.ll
@@ -32,3 +32,32 @@ entry:
   store <2 x double> %v2, ptr %res
   ret void
 }
+
+;; 1.0 / vec
+define void @one_fdiv_v4f32(ptr %res, ptr %a0) nounwind {
+; CHECK-LABEL: one_fdiv_v4f32:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vld $vr0, $a1, 0
+; CHECK-NEXT:    vfrecip.s $vr0, $vr0
+; CHECK-NEXT:    vst $vr0, $a0, 0
+; CHECK-NEXT:    ret
+entry:
+  %v0 = load <4 x float>, ptr %a0
+  %div = fdiv <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>, %v0
+  store <4 x float> %div, ptr %res
+  ret void
+}
+
+define void @one_fdiv_v2f64(ptr %res, ptr %a0) nounwind {
+; CHECK-LABEL: one_fdiv_v2f64:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vld $vr0, $a1, 0
+; CHECK-NEXT:    vfrecip.d $vr0, $vr0
+; CHECK-NEXT:    vst $vr0, $a0, 0
+; CHECK-NEXT:    ret
+entry:
+  %v0 = load <2 x double>, ptr %a0
+  %div = fdiv <2 x double> <double 1.0, double 1.0>, %v0
+  store <2 x double> %div, ptr %res
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/74795


More information about the llvm-commits mailing list