[llvm] 3665837 - [RISCV] Add support for fixed vector sqrt.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 12 15:36:46 PST 2021


Author: Craig Topper
Date: 2021-02-12T15:33:29-08:00
New Revision: 36658376d5d4103b3828c726f211030ebc4f84b6

URL: https://github.com/llvm/llvm-project/commit/36658376d5d4103b3828c726f211030ebc4f84b6
DIFF: https://github.com/llvm/llvm-project/commit/36658376d5d4103b3828c726f211030ebc4f84b6.diff

LOG: [RISCV] Add support for fixed vector sqrt.

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.h
    llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
    llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 06bdbfd75f19..1ce754c37214 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -577,6 +577,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
         setOperationAction(ISD::FMUL, VT, Custom);
         setOperationAction(ISD::FDIV, VT, Custom);
         setOperationAction(ISD::FNEG, VT, Custom);
+        setOperationAction(ISD::FSQRT, VT, Custom);
         setOperationAction(ISD::FMA, VT, Custom);
       }
     }
@@ -1209,6 +1210,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     return lowerToScalableOp(Op, DAG, RISCVISD::FDIV_VL);
   case ISD::FNEG:
     return lowerToScalableOp(Op, DAG, RISCVISD::FNEG_VL);
+  case ISD::FSQRT:
+    return lowerToScalableOp(Op, DAG, RISCVISD::FSQRT_VL);
   case ISD::FMA:
     return lowerToScalableOp(Op, DAG, RISCVISD::FMA_VL);
   case ISD::SMIN:
@@ -4739,6 +4742,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(FMUL_VL)
   NODE_NAME_CASE(FDIV_VL)
   NODE_NAME_CASE(FNEG_VL)
+  NODE_NAME_CASE(FSQRT_VL)
   NODE_NAME_CASE(FMA_VL)
   NODE_NAME_CASE(SMIN_VL)
   NODE_NAME_CASE(SMAX_VL)

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 78177c8451cb..edb14c60bf9a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -162,6 +162,7 @@ enum NodeType : unsigned {
   FMUL_VL,
   FDIV_VL,
   FNEG_VL,
+  FSQRT_VL,
   FMA_VL,
   SMIN_VL,
   SMAX_VL,

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index a9cb535d8901..bc45922b89eb 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -51,28 +51,29 @@ def riscv_vle_vl : SDNode<"RISCVISD::VLE_VL", SDT_RISCVVLE_VL,
 def riscv_vse_vl : SDNode<"RISCVISD::VSE_VL", SDT_RISCVVSE_VL,
                           [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
 
-def riscv_add_vl  : SDNode<"RISCVISD::ADD_VL",  SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
-def riscv_sub_vl  : SDNode<"RISCVISD::SUB_VL",  SDT_RISCVIntBinOp_VL>;
-def riscv_mul_vl  : SDNode<"RISCVISD::MUL_VL",  SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
-def riscv_and_vl  : SDNode<"RISCVISD::AND_VL",  SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
-def riscv_or_vl   : SDNode<"RISCVISD::OR_VL",   SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
-def riscv_xor_vl  : SDNode<"RISCVISD::XOR_VL",  SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
-def riscv_sdiv_vl : SDNode<"RISCVISD::SDIV_VL", SDT_RISCVIntBinOp_VL>;
-def riscv_srem_vl : SDNode<"RISCVISD::SREM_VL", SDT_RISCVIntBinOp_VL>;
-def riscv_udiv_vl : SDNode<"RISCVISD::UDIV_VL", SDT_RISCVIntBinOp_VL>;
-def riscv_urem_vl : SDNode<"RISCVISD::UREM_VL", SDT_RISCVIntBinOp_VL>;
-def riscv_shl_vl  : SDNode<"RISCVISD::SHL_VL",  SDT_RISCVIntBinOp_VL>;
-def riscv_sra_vl  : SDNode<"RISCVISD::SRA_VL",  SDT_RISCVIntBinOp_VL>;
-def riscv_srl_vl  : SDNode<"RISCVISD::SRL_VL",  SDT_RISCVIntBinOp_VL>;
-def riscv_smin_vl : SDNode<"RISCVISD::SMIN_VL", SDT_RISCVIntBinOp_VL>;
-def riscv_smax_vl : SDNode<"RISCVISD::SMAX_VL", SDT_RISCVIntBinOp_VL>;
-def riscv_umin_vl : SDNode<"RISCVISD::UMIN_VL", SDT_RISCVIntBinOp_VL>;
-def riscv_umax_vl : SDNode<"RISCVISD::UMAX_VL", SDT_RISCVIntBinOp_VL>;
-def riscv_fadd_vl : SDNode<"RISCVISD::FADD_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
-def riscv_fsub_vl : SDNode<"RISCVISD::FSUB_VL", SDT_RISCVFPBinOp_VL>;
-def riscv_fmul_vl : SDNode<"RISCVISD::FMUL_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
-def riscv_fdiv_vl : SDNode<"RISCVISD::FDIV_VL", SDT_RISCVFPBinOp_VL>;
-def riscv_fneg_vl : SDNode<"RISCVISD::FNEG_VL", SDT_RISCVFPUnOp_VL>;
+def riscv_add_vl   : SDNode<"RISCVISD::ADD_VL",   SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
+def riscv_sub_vl   : SDNode<"RISCVISD::SUB_VL",   SDT_RISCVIntBinOp_VL>;
+def riscv_mul_vl   : SDNode<"RISCVISD::MUL_VL",   SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
+def riscv_and_vl   : SDNode<"RISCVISD::AND_VL",   SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
+def riscv_or_vl    : SDNode<"RISCVISD::OR_VL",    SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
+def riscv_xor_vl   : SDNode<"RISCVISD::XOR_VL",   SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
+def riscv_sdiv_vl  : SDNode<"RISCVISD::SDIV_VL",  SDT_RISCVIntBinOp_VL>;
+def riscv_srem_vl  : SDNode<"RISCVISD::SREM_VL",  SDT_RISCVIntBinOp_VL>;
+def riscv_udiv_vl  : SDNode<"RISCVISD::UDIV_VL",  SDT_RISCVIntBinOp_VL>;
+def riscv_urem_vl  : SDNode<"RISCVISD::UREM_VL",  SDT_RISCVIntBinOp_VL>;
+def riscv_shl_vl   : SDNode<"RISCVISD::SHL_VL",   SDT_RISCVIntBinOp_VL>;
+def riscv_sra_vl   : SDNode<"RISCVISD::SRA_VL",   SDT_RISCVIntBinOp_VL>;
+def riscv_srl_vl   : SDNode<"RISCVISD::SRL_VL",   SDT_RISCVIntBinOp_VL>;
+def riscv_smin_vl  : SDNode<"RISCVISD::SMIN_VL",  SDT_RISCVIntBinOp_VL>;
+def riscv_smax_vl  : SDNode<"RISCVISD::SMAX_VL",  SDT_RISCVIntBinOp_VL>;
+def riscv_umin_vl  : SDNode<"RISCVISD::UMIN_VL",  SDT_RISCVIntBinOp_VL>;
+def riscv_umax_vl  : SDNode<"RISCVISD::UMAX_VL",  SDT_RISCVIntBinOp_VL>;
+def riscv_fadd_vl  : SDNode<"RISCVISD::FADD_VL",  SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
+def riscv_fsub_vl  : SDNode<"RISCVISD::FSUB_VL",  SDT_RISCVFPBinOp_VL>;
+def riscv_fmul_vl  : SDNode<"RISCVISD::FMUL_VL",  SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
+def riscv_fdiv_vl  : SDNode<"RISCVISD::FDIV_VL",  SDT_RISCVFPBinOp_VL>;
+def riscv_fneg_vl  : SDNode<"RISCVISD::FNEG_VL",  SDT_RISCVFPUnOp_VL>;
+def riscv_fsqrt_vl : SDNode<"RISCVISD::FSQRT_VL", SDT_RISCVFPUnOp_VL>;
 
 def SDT_RISCVVecFMA_VL : SDTypeProfile<1, 5, [SDTCisSameAs<0, 1>,
                                               SDTCisSameAs<0, 2>,
@@ -440,9 +441,15 @@ foreach vti = AllFloatVectors in {
                  GPR:$vl, vti.SEW)>;
 }
 
-// 14.12. Vector Floating-Point Sign-Injection Instructions
-// Handle fneg with VFSGNJN using the same input for both operands.
 foreach vti = AllFloatVectors in {
+  // 14.8. Vector Floating-Point Square-Root Instruction
+  def : Pat<(riscv_fsqrt_vl (vti.Vector vti.RegClass:$rs2), (vti.Mask true_mask),
+                            (XLenVT (VLOp GPR:$vl))),
+            (!cast<Instruction>("PseudoVFSQRT_V_"# vti.LMul.MX)
+                 vti.RegClass:$rs2, GPR:$vl, vti.SEW)>;
+
+  // 14.12. Vector Floating-Point Sign-Injection Instructions
+  // Handle fneg with VFSGNJN using the same input for both operands.
   def : Pat<(riscv_fneg_vl (vti.Vector vti.RegClass:$rs), (vti.Mask true_mask),
                            (XLenVT (VLOp GPR:$vl))),
             (!cast<Instruction>("PseudoVFSGNJN_VV_"# vti.LMul.MX)

diff  --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll
index 7407aa8aa5b7..2c54c4690b08 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll
@@ -253,6 +253,54 @@ define void @fneg_v2f64(<2 x double>* %x) {
   ret void
 }
 
+define void @sqrt_v8f16(<8 x half>* %x) {
+; CHECK-LABEL: sqrt_v8f16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a1, zero, 8
+; CHECK-NEXT:    vsetvli a1, a1, e16,m1,ta,mu
+; CHECK-NEXT:    vle16.v v25, (a0)
+; CHECK-NEXT:    vfsqrt.v v25, v25
+; CHECK-NEXT:    vse16.v v25, (a0)
+; CHECK-NEXT:    ret
+  %a = load <8 x half>, <8 x half>* %x
+  %b = call <8 x half> @llvm.sqrt.v8f16(<8 x half> %a)
+  store <8 x half> %b, <8 x half>* %x
+  ret void
+}
+declare <8 x half> @llvm.sqrt.v8f16(<8 x half>)
+
+define void @sqrt_v4f32(<4 x float>* %x) {
+; CHECK-LABEL: sqrt_v4f32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a1, zero, 4
+; CHECK-NEXT:    vsetvli a1, a1, e32,m1,ta,mu
+; CHECK-NEXT:    vle32.v v25, (a0)
+; CHECK-NEXT:    vfsqrt.v v25, v25
+; CHECK-NEXT:    vse32.v v25, (a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x float>, <4 x float>* %x
+  %b = call <4 x float> @llvm.sqrt.v4f32(<4 x float> %a)
+  store <4 x float> %b, <4 x float>* %x
+  ret void
+}
+declare <4 x float> @llvm.sqrt.v4f32(<4 x float>)
+
+define void @sqrt_v2f64(<2 x double>* %x) {
+; CHECK-LABEL: sqrt_v2f64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a1, zero, 2
+; CHECK-NEXT:    vsetvli a1, a1, e64,m1,ta,mu
+; CHECK-NEXT:    vle64.v v25, (a0)
+; CHECK-NEXT:    vfsqrt.v v25, v25
+; CHECK-NEXT:    vse64.v v25, (a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x double>, <2 x double>* %x
+  %b = call <2 x double> @llvm.sqrt.v2f64(<2 x double> %a)
+  store <2 x double> %b, <2 x double>* %x
+  ret void
+}
+declare <2 x double> @llvm.sqrt.v2f64(<2 x double>)
+
 define void @fma_v8f16(<8 x half>* %x, <8 x half>* %y, <8 x half>* %z) {
 ; CHECK-LABEL: fma_v8f16:
 ; CHECK:       # %bb.0:


        


More information about the llvm-commits mailing list