[llvm] 55eb93b - [RISCV] Remove RISCVISD::FP_EXTEND_BF16. (#106939)

via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 2 10:14:07 PDT 2024


Author: Craig Topper
Date: 2024-09-02T10:14:04-07:00
New Revision: 55eb93b2688de99ada14c71804af99502276ac79

URL: https://github.com/llvm/llvm-project/commit/55eb93b2688de99ada14c71804af99502276ac79
DIFF: https://github.com/llvm/llvm-project/commit/55eb93b2688de99ada14c71804af99502276ac79.diff

LOG: [RISCV] Remove RISCVISD::FP_EXTEND_BF16. (#106939)

I don't think we need this node. We can isel fp_extend directly.
fp_extend to f64 requires two instructions (fcvt.s.bf16 followed by
fcvt.d.s), but we can emit them with a single isel pattern.

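I have not removed RISCVISD::FP_ROUND_BF16 because f64->bf16 needs more
work to avoid double rounding: rounding f64->f32->bf16 in two steps can
produce a different result than a single correctly rounded f64->bf16
conversion.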

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.h
    llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
    llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index d02078372b24a2..250d1c60b9f59e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -452,8 +452,6 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::BITCAST, MVT::i16, Custom);
     setOperationAction(ISD::BITCAST, MVT::bf16, Custom);
     setOperationAction(ISD::FP_ROUND, MVT::bf16, Custom);
-    setOperationAction(ISD::FP_EXTEND, MVT::f32, Custom);
-    setOperationAction(ISD::FP_EXTEND, MVT::f64, Custom);
     setOperationAction(ISD::ConstantFP, MVT::bf16, Expand);
     setOperationAction(ISD::SELECT_CC, MVT::bf16, Expand);
     setOperationAction(ISD::BR_CC, MVT::bf16, Expand);
@@ -6500,18 +6498,6 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
       return SplitVectorOp(Op, DAG);
     return lowerFMAXIMUM_FMINIMUM(Op, DAG, Subtarget);
   case ISD::FP_EXTEND: {
-    SDLoc DL(Op);
-    EVT VT = Op.getValueType();
-    SDValue Op0 = Op.getOperand(0);
-    EVT Op0VT = Op0.getValueType();
-    if (VT == MVT::f32 && Op0VT == MVT::bf16 && Subtarget.hasStdExtZfbfmin())
-      return DAG.getNode(RISCVISD::FP_EXTEND_BF16, DL, MVT::f32, Op0);
-    if (VT == MVT::f64 && Op0VT == MVT::bf16 && Subtarget.hasStdExtZfbfmin()) {
-      SDValue FloatVal =
-          DAG.getNode(RISCVISD::FP_EXTEND_BF16, DL, MVT::f32, Op0);
-      return DAG.getNode(ISD::FP_EXTEND, DL, MVT::f64, FloatVal);
-    }
-
     if (!Op.getValueType().isVector())
       return Op;
     return lowerVectorFPExtendOrRoundLike(Op, DAG);
@@ -20463,7 +20449,6 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(STRICT_FCVT_W_RV64)
   NODE_NAME_CASE(STRICT_FCVT_WU_RV64)
   NODE_NAME_CASE(FP_ROUND_BF16)
-  NODE_NAME_CASE(FP_EXTEND_BF16)
   NODE_NAME_CASE(FROUND)
   NODE_NAME_CASE(FCLASS)
   NODE_NAME_CASE(FSGNJX)

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 9ae35173ba0cb3..29a16282ed001d 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -117,7 +117,6 @@ enum NodeType : unsigned {
   FCVT_WU_RV64,
 
   FP_ROUND_BF16,
-  FP_EXTEND_BF16,
 
   // Rounds an FP value to its corresponding integer in the same FP format.
   // First operand is the value to round, the second operand is the largest

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index 0f435c4ff3d315..f12f82cb159529 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -677,8 +677,7 @@ multiclass VPatWidenBinaryFPSDNode_VV_VF_WV_WF_RM<SDNode op,
       VPatWidenBinaryFPSDNode_WV_WF_RM<op, instruction_name>;
 
 multiclass VPatWidenFPMulAccSDNode_VV_VF_RM<string instruction_name,
-                                            list <VTypeInfoToWide> vtiToWtis,
-                                            PatFrags extop> {
+                                            list <VTypeInfoToWide> vtiToWtis> {
   foreach vtiToWti = vtiToWtis in {
     defvar vti = vtiToWti.Vti;
     defvar wti = vtiToWti.Wti;
@@ -702,7 +701,7 @@ multiclass VPatWidenFPMulAccSDNode_VV_VF_RM<string instruction_name,
                    FRM_DYN,
                    vti.AVL, vti.Log2SEW, TAIL_AGNOSTIC)>;
       def : Pat<(fma (wti.Vector (SplatFPOp
-                                      (extop (vti.Scalar vti.ScalarRegClass:$rs1)))),
+                                      (fpext_oneuse (vti.Scalar vti.ScalarRegClass:$rs1)))),
                      (wti.Vector (riscv_fpextend_vl_oneuse
                                       (vti.Vector vti.RegClass:$rs2),
                                       (vti.Mask true_mask), (XLenVT srcvalue))),
@@ -1290,11 +1289,9 @@ foreach fvti = AllFloatVectors in {
 
 // 13.7. Vector Widening Floating-Point Fused Multiply-Add Instructions
 defm : VPatWidenFPMulAccSDNode_VV_VF_RM<"PseudoVFWMACC",
-                                        AllWidenableFloatVectors,
-                                        fpext_oneuse>;
+                                        AllWidenableFloatVectors>;
 defm : VPatWidenFPMulAccSDNode_VV_VF_RM<"PseudoVFWMACCBF16",
-                                        AllWidenableBFloatToFloatVectors,
-                                        riscv_fpextend_bf16_oneuse>;
+                                        AllWidenableBFloatToFloatVectors>;
 defm : VPatWidenFPNegMulAccSDNode_VV_VF_RM<"PseudoVFWNMACC">;
 defm : VPatWidenFPMulSacSDNode_VV_VF_RM<"PseudoVFWMSAC">;
 defm : VPatWidenFPNegMulSacSDNode_VV_VF_RM<"PseudoVFWNMSAC">;

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td
index 88b66e7fc49aad..bf6272317fda4d 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td
@@ -19,17 +19,9 @@
 
 def SDT_RISCVFP_ROUND_BF16
     : SDTypeProfile<1, 1, [SDTCisVT<0, bf16>, SDTCisVT<1, f32>]>;
-def SDT_RISCVFP_EXTEND_BF16
-    : SDTypeProfile<1, 1, [SDTCisVT<0, f32>, SDTCisVT<1, bf16>]>;
 
 def riscv_fpround_bf16
     : SDNode<"RISCVISD::FP_ROUND_BF16", SDT_RISCVFP_ROUND_BF16>;
-def riscv_fpextend_bf16
-    : SDNode<"RISCVISD::FP_EXTEND_BF16", SDT_RISCVFP_EXTEND_BF16>;
-def riscv_fpextend_bf16_oneuse : PatFrag<(ops node:$A),
-                                         (riscv_fpextend_bf16 node:$A), [{
-  return N->hasOneUse();
-}]>;
 
 //===----------------------------------------------------------------------===//
 // Instructions
@@ -57,7 +49,7 @@ def : StPat<store, FSH, FPR16, bf16>;
 // f32 -> bf16, bf16 -> f32
 def : Pat<(bf16 (riscv_fpround_bf16 FPR32:$rs1)),
           (FCVT_BF16_S FPR32:$rs1, FRM_DYN)>;
-def : Pat<(riscv_fpextend_bf16 (bf16 FPR16:$rs1)),
+def : Pat<(fpextend (bf16 FPR16:$rs1)),
           (FCVT_S_BF16 FPR16:$rs1, FRM_DYN)>;
 
 // Moves (no conversion)
@@ -87,3 +79,9 @@ def : Pat<(i64 (any_fp_to_uint (bf16 FPR16:$rs1))), (FCVT_LU_S (FCVT_S_BF16 $rs1
 def : Pat<(bf16 (any_sint_to_fp (i64 GPR:$rs1))), (FCVT_BF16_S (FCVT_S_L $rs1, FRM_DYN), FRM_DYN)>;
 def : Pat<(bf16 (any_uint_to_fp (i64 GPR:$rs1))), (FCVT_BF16_S (FCVT_S_LU $rs1, FRM_DYN), FRM_DYN)>;
 }
+
+let Predicates = [HasStdExtZfbfmin, HasStdExtD] in {
+// bf16 -> f64
+def : Pat<(fpextend (bf16 FPR16:$rs1)),
+          (FCVT_D_S (FCVT_S_BF16 FPR16:$rs1, FRM_DYN), FRM_RNE)>;
+}


        


More information about the llvm-commits mailing list