[llvm] f375ee3 - [RISCV] Add codegen for Zfbfmin instructions

Jun Sha via llvm-commits llvm-commits at lists.llvm.org
Sun Jul 23 19:38:24 PDT 2023


Author: Jun Sha (Joshua)
Date: 2023-07-24T10:37:58+08:00
New Revision: f375ee36c4e1c87b8ed47ef44d3613bd5756f57a

URL: https://github.com/llvm/llvm-project/commit/f375ee36c4e1c87b8ed47ef44d3613bd5756f57a
DIFF: https://github.com/llvm/llvm-project/commit/f375ee36c4e1c87b8ed47ef44d3613bd5756f57a.diff

LOG: [RISCV] Add codegen for Zfbfmin instructions

The implementation in https://reviews.llvm.org/D151313 handles the case where Zfbfmin is not available. This patch adds codegen support for the six instructions provided by the Zfbfmin extension.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D153234

Added: 
    llvm/test/CodeGen/RISCV/zfbfmin.ll

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.h
    llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td
    llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 67a8ac5b6ee767..f642124e072ccf 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -116,6 +116,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
   if (Subtarget.hasStdExtZfhOrZfhmin())
     addRegisterClass(MVT::f16, &RISCV::FPR16RegClass);
+  if (Subtarget.hasStdExtZfbfmin())
+    addRegisterClass(MVT::bf16, &RISCV::FPR16RegClass);
   if (Subtarget.hasStdExtF())
     addRegisterClass(MVT::f32, &RISCV::FPR32RegClass);
   if (Subtarget.hasStdExtD())
@@ -359,6 +361,15 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
   if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin())
     setOperationAction(ISD::BITCAST, MVT::i16, Custom);
+
+  if (Subtarget.hasStdExtZfbfmin()) {
+    setOperationAction(ISD::BITCAST, MVT::i16, Custom);
+    setOperationAction(ISD::BITCAST, MVT::bf16, Custom);
+    setOperationAction(ISD::FP_ROUND, MVT::bf16, Custom);
+    setOperationAction(ISD::FP_EXTEND, MVT::f32, Custom);
+    setOperationAction(ISD::FP_EXTEND, MVT::f64, Custom);
+    setOperationAction(ISD::ConstantFP, MVT::bf16, Expand);
+  }
 
   if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) {
     if (Subtarget.hasStdExtZfhOrZhinx()) {
@@ -4768,6 +4779,12 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
       SDValue FPConv = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::f16, NewOp0);
       return FPConv;
     }
+    if (VT == MVT::bf16 && Op0VT == MVT::i16 &&
+        Subtarget.hasStdExtZfbfmin()) {
+      SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Op0);
+      SDValue FPConv = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::bf16, NewOp0);
+      return FPConv;
+    }
     if (VT == MVT::f32 && Op0VT == MVT::i32 && Subtarget.is64Bit() &&
         Subtarget.hasStdExtFOrZfinx()) {
       SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op0);
@@ -4931,11 +4948,42 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     }
     return SDValue();
   }
-  case ISD::FP_EXTEND:
-  case ISD::FP_ROUND:
+  case ISD::FP_EXTEND: {
+    SDLoc DL(Op);
+    EVT VT = Op.getValueType();
+    SDValue Op0 = Op.getOperand(0);
+    EVT Op0VT = Op0.getValueType();
+    if (VT == MVT::f32 && Op0VT == MVT::bf16 && Subtarget.hasStdExtZfbfmin())
+      return DAG.getNode(RISCVISD::FP_EXTEND_BF16, DL, MVT::f32, Op0);
+    if (VT == MVT::f64 && Op0VT == MVT::bf16 && Subtarget.hasStdExtZfbfmin()) {
+      SDValue FloatVal =
+          DAG.getNode(RISCVISD::FP_EXTEND_BF16, DL, MVT::f32, Op0);
+      return DAG.getNode(ISD::FP_EXTEND, DL, MVT::f64, FloatVal);
+    }
+
+    if (!Op.getValueType().isVector())
+      return Op;
+    return lowerVectorFPExtendOrRoundLike(Op, DAG);
+  }
+  case ISD::FP_ROUND: {
+    SDLoc DL(Op);
+    EVT VT = Op.getValueType();
+    SDValue Op0 = Op.getOperand(0);
+    EVT Op0VT = Op0.getValueType();
+    if (VT == MVT::bf16 && Op0VT == MVT::f32 && Subtarget.hasStdExtZfbfmin())
+      return DAG.getNode(RISCVISD::FP_ROUND_BF16, DL, MVT::bf16, Op0);
+    if (VT == MVT::bf16 && Op0VT == MVT::f64 && Subtarget.hasStdExtZfbfmin() &&
+        Subtarget.hasStdExtDOrZdinx()) {
+      SDValue FloatVal =
+          DAG.getNode(ISD::FP_ROUND, DL, MVT::f32, Op0,
+                      DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
+      return DAG.getNode(RISCVISD::FP_ROUND_BF16, DL, MVT::bf16, FloatVal);
+    }
+
     if (!Op.getValueType().isVector())
       return Op;
     return lowerVectorFPExtendOrRoundLike(Op, DAG);
+  }
   case ISD::STRICT_FP_ROUND:
   case ISD::STRICT_FP_EXTEND:
     return lowerStrictFPExtendOrRoundLike(Op, DAG);
@@ -9926,6 +9974,10 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
         Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) {
       SDValue FPConv = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Op0);
       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FPConv));
+    } else if (VT == MVT::i16 && Op0VT == MVT::bf16 &&
+               Subtarget.hasStdExtZfbfmin()) {
+      SDValue FPConv = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Op0);
+      Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FPConv));
     } else if (VT == MVT::i32 && Op0VT == MVT::f32 && Subtarget.is64Bit() &&
                Subtarget.hasStdExtFOrZfinx()) {
       SDValue FPConv =
@@ -14867,7 +14919,8 @@ bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
   // similar local variables rather than directly checking against the target
   // ABI.
 
-  if (UseGPRForF16_F32 && (ValVT == MVT::f16 || ValVT == MVT::f32)) {
+  if (UseGPRForF16_F32 &&
+      (ValVT == MVT::f16 || ValVT == MVT::bf16 || ValVT == MVT::f32)) {
     LocVT = XLenVT;
     LocInfo = CCValAssign::BCvt;
   } else if (UseGPRForF64 && XLen == 64 && ValVT == MVT::f64) {
@@ -14960,7 +15013,7 @@ bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
   unsigned StoreSizeBytes = XLen / 8;
   Align StackAlign = Align(XLen / 8);
 
-  if (ValVT == MVT::f16 && !UseGPRForF16_F32)
+  if ((ValVT == MVT::f16 || ValVT == MVT::bf16) && !UseGPRForF16_F32)
     Reg = State.AllocateReg(ArgFPR16s);
   else if (ValVT == MVT::f32 && !UseGPRForF16_F32)
     Reg = State.AllocateReg(ArgFPR32s);
@@ -15117,8 +15170,9 @@ static SDValue convertLocVTToValVT(SelectionDAG &DAG, SDValue Val,
       Val = convertFromScalableVector(VA.getValVT(), Val, DAG, Subtarget);
     break;
   case CCValAssign::BCvt:
-    if (VA.getLocVT().isInteger() && VA.getValVT() == MVT::f16)
-      Val = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::f16, Val);
+    if (VA.getLocVT().isInteger() &&
+        (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16))
+      Val = DAG.getNode(RISCVISD::FMV_H_X, DL, VA.getValVT(), Val);
     else if (VA.getLocVT() == MVT::i64 && VA.getValVT() == MVT::f32)
       Val = DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, Val);
     else
@@ -15176,7 +15230,8 @@ static SDValue convertValVTToLocVT(SelectionDAG &DAG, SDValue Val,
       Val = convertToScalableVector(LocVT, Val, DAG, Subtarget);
     break;
   case CCValAssign::BCvt:
-    if (VA.getLocVT().isInteger() && VA.getValVT() == MVT::f16)
+    if (VA.getLocVT().isInteger() &&
+        (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16))
       Val = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, VA.getLocVT(), Val);
     else if (VA.getLocVT() == MVT::i64 && VA.getValVT() == MVT::f32)
       Val = DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Val);
@@ -16196,6 +16251,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(FCVT_WU_RV64)
   NODE_NAME_CASE(STRICT_FCVT_W_RV64)
   NODE_NAME_CASE(STRICT_FCVT_WU_RV64)
+  NODE_NAME_CASE(FP_ROUND_BF16)
+  NODE_NAME_CASE(FP_EXTEND_BF16)
   NODE_NAME_CASE(FROUND)
   NODE_NAME_CASE(FPCLASS)
   NODE_NAME_CASE(READ_CYCLE_WIDE)

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index a6c7100ddf42b7..ec90e3c0cdcdde 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -111,6 +111,9 @@ enum NodeType : unsigned {
   FCVT_W_RV64,
   FCVT_WU_RV64,
 
+  FP_ROUND_BF16,
+  FP_EXTEND_BF16,
+
   // Rounds an FP value to its corresponding integer in the same FP format.
   // First operand is the value to round, the second operand is the largest
   // integer that can be represented exactly in the FP format. This will be

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td
index 1f423591d3dde8..35f9f03f61a13f 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td
@@ -13,6 +13,20 @@
 //
 //===----------------------------------------------------------------------===//
 
+//===----------------------------------------------------------------------===//
+// RISC-V specific DAG Nodes.
+//===----------------------------------------------------------------------===//
+
+def SDT_RISCVFP_ROUND_BF16
+    : SDTypeProfile<1, 1, [SDTCisVT<0, bf16>, SDTCisVT<1, f32>]>;
+def SDT_RISCVFP_EXTEND_BF16
+    : SDTypeProfile<1, 1, [SDTCisVT<0, f32>, SDTCisVT<1, bf16>]>;
+
+def riscv_fpround_bf16
+    : SDNode<"RISCVISD::FP_ROUND_BF16", SDT_RISCVFP_ROUND_BF16>;
+def riscv_fpextend_bf16
+    : SDNode<"RISCVISD::FP_EXTEND_BF16", SDT_RISCVFP_EXTEND_BF16>;
+
 //===----------------------------------------------------------------------===//
 // Instructions
 //===----------------------------------------------------------------------===//
@@ -23,3 +37,27 @@ def FCVT_BF16_S : FPUnaryOp_r_frm<0b0100010, 0b01000, FPR16, FPR32, "fcvt.bf16.s
 def FCVT_S_BF16 : FPUnaryOp_r_frm<0b0100000, 0b00110, FPR32, FPR16, "fcvt.s.bf16">,
                   Sched<[WriteFCvtF32ToF16, ReadFCvtF32ToF16]>;
 } // Predicates = [HasStdExtZfbfmin]
+
+//===----------------------------------------------------------------------===//
+// Pseudo-instructions and codegen patterns
+//===----------------------------------------------------------------------===//
+
+let Predicates = [HasStdExtZfbfmin] in {
+/// Loads
+def : LdPat<load, FLH, bf16>;
+
+/// Stores
+def : StPat<store, FSH, FPR16, bf16>;
+
+/// Float conversion operations
+// f32 -> bf16, bf16 -> f32
+def : Pat<(bf16 (riscv_fpround_bf16 FPR32:$rs1)),
+          (FCVT_BF16_S FPR32:$rs1, FRM_DYN)>;
+def : Pat<(riscv_fpextend_bf16 (bf16 FPR16:$rs1)),
+          (FCVT_S_BF16 FPR16:$rs1, FRM_DYN)>;
+
+// Moves (no conversion)
+def : Pat<(bf16 (riscv_fmv_h_x GPR:$src)), (FMV_H_X GPR:$src)>;
+def : Pat<(riscv_fmv_x_anyexth (bf16 FPR16:$src)), (FMV_X_H FPR16:$src)>;
+def : Pat<(riscv_fmv_x_signexth (bf16 FPR16:$src)), (FMV_X_H FPR16:$src)>;
+} // Predicates = [HasStdExtZfbfmin]

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td
index 3ea338d9ed20dd..5dc02e5fa9f9e9 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td
@@ -16,9 +16,9 @@
 //===----------------------------------------------------------------------===//
 
 def SDT_RISCVFMV_H_X
-    : SDTypeProfile<1, 1, [SDTCisVT<0, f16>, SDTCisVT<1, XLenVT>]>;
+    : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVT<1, XLenVT>]>;
 def SDT_RISCVFMV_X_EXTH
-    : SDTypeProfile<1, 1, [SDTCisVT<0, XLenVT>, SDTCisVT<1, f16>]>;
+    : SDTypeProfile<1, 1, [SDTCisVT<0, XLenVT>, SDTCisFP<1>]>;
 
 def riscv_fmv_h_x
     : SDNode<"RISCVISD::FMV_H_X", SDT_RISCVFMV_H_X>;
@@ -438,7 +438,7 @@ def : Pat<(f16 (any_fpround FPR32:$rs1)), (FCVT_H_S FPR32:$rs1, FRM_DYN)>;
 def : Pat<(any_fpextend (f16 FPR16:$rs1)), (FCVT_S_H FPR16:$rs1)>;
 
 // Moves (no conversion)
-def : Pat<(riscv_fmv_h_x GPR:$src), (FMV_H_X GPR:$src)>;
+def : Pat<(f16 (riscv_fmv_h_x GPR:$src)), (FMV_H_X GPR:$src)>;
 def : Pat<(riscv_fmv_x_anyexth (f16 FPR16:$src)), (FMV_X_H FPR16:$src)>;
 def : Pat<(riscv_fmv_x_signexth (f16 FPR16:$src)), (FMV_X_H FPR16:$src)>;
 
@@ -453,7 +453,7 @@ def : Pat<(any_fpround FPR32INX:$rs1), (FCVT_H_S_INX FPR32INX:$rs1, FRM_DYN)>;
 def : Pat<(any_fpextend FPR16INX:$rs1), (FCVT_S_H_INX FPR16INX:$rs1)>;
 
 // Moves (no conversion)
-def : Pat<(riscv_fmv_h_x GPR:$src), (COPY_TO_REGCLASS GPR:$src, GPR)>;
+def : Pat<(f16 (riscv_fmv_h_x GPR:$src)), (COPY_TO_REGCLASS GPR:$src, GPR)>;
 def : Pat<(riscv_fmv_x_anyexth FPR16INX:$src), (COPY_TO_REGCLASS FPR16INX:$src, GPR)>;
 def : Pat<(riscv_fmv_x_signexth FPR16INX:$src), (COPY_TO_REGCLASS FPR16INX:$src, GPR)>;
 

diff  --git a/llvm/test/CodeGen/RISCV/zfbfmin.ll b/llvm/test/CodeGen/RISCV/zfbfmin.ll
new file mode 100644
index 00000000000000..b32e6dc0b14b5c
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/zfbfmin.ll
@@ -0,0 +1,92 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+experimental-zfbfmin -verify-machineinstrs \
+; RUN:   -target-abi ilp32d < %s | FileCheck -check-prefix=CHECKIZFBFMIN %s
+; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+experimental-zfbfmin -verify-machineinstrs \
+; RUN:   -target-abi lp64d < %s | FileCheck -check-prefix=CHECKIZFBFMIN %s
+
+define bfloat @bitcast_bf16_i16(i16 %a) nounwind {
+; CHECKIZFBFMIN-LABEL: bitcast_bf16_i16:
+; CHECKIZFBFMIN:       # %bb.0:
+; CHECKIZFBFMIN-NEXT:    fmv.h.x fa0, a0
+; CHECKIZFBFMIN-NEXT:    ret
+  %1 = bitcast i16 %a to bfloat
+  ret bfloat %1
+}
+
+define i16 @bitcast_i16_bf16(bfloat %a) nounwind {
+; CHECKIZFBFMIN-LABEL: bitcast_i16_bf16:
+; CHECKIZFBFMIN:       # %bb.0:
+; CHECKIZFBFMIN-NEXT:    fmv.x.h a0, fa0
+; CHECKIZFBFMIN-NEXT:    ret
+  %1 = bitcast bfloat %a to i16
+  ret i16 %1
+}
+
+define bfloat @fcvt_bf16_s(float %a) nounwind {
+; CHECKIZFBFMIN-LABEL: fcvt_bf16_s:
+; CHECKIZFBFMIN:       # %bb.0:
+; CHECKIZFBFMIN-NEXT:    fcvt.bf16.s fa0, fa0
+; CHECKIZFBFMIN-NEXT:    ret
+  %1 = fptrunc float %a to bfloat
+  ret bfloat %1
+}
+
+define float @fcvt_s_bf16(bfloat %a) nounwind {
+; CHECKIZFBFMIN-LABEL: fcvt_s_bf16:
+; CHECKIZFBFMIN:       # %bb.0:
+; CHECKIZFBFMIN-NEXT:    fcvt.s.bf16 fa0, fa0
+; CHECKIZFBFMIN-NEXT:    ret
+  %1 = fpext bfloat %a to float
+  ret float %1
+}
+
+define bfloat @fcvt_bf16_d(double %a) nounwind {
+; CHECKIZFBFMIN-LABEL: fcvt_bf16_d:
+; CHECKIZFBFMIN:       # %bb.0:
+; CHECKIZFBFMIN-NEXT:    fcvt.s.d fa5, fa0
+; CHECKIZFBFMIN-NEXT:    fcvt.bf16.s fa0, fa5
+; CHECKIZFBFMIN-NEXT:    ret
+  %1 = fptrunc double %a to bfloat
+  ret bfloat %1
+}
+
+define double @fcvt_d_bf16(bfloat %a) nounwind {
+; CHECKIZFBFMIN-LABEL: fcvt_d_bf16:
+; CHECKIZFBFMIN:       # %bb.0:
+; CHECKIZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa0
+; CHECKIZFBFMIN-NEXT:    fcvt.d.s fa0, fa5
+; CHECKIZFBFMIN-NEXT:    ret
+  %1 = fpext bfloat %a to double
+  ret double %1
+}
+
+define bfloat @bfloat_load(ptr %a) nounwind {
+; CHECKIZFBFMIN-LABEL: bfloat_load:
+; CHECKIZFBFMIN:       # %bb.0:
+; CHECKIZFBFMIN-NEXT:    flh fa0, 6(a0)
+; CHECKIZFBFMIN-NEXT:    ret
+  %1 = getelementptr bfloat, ptr %a, i32 3
+  %2 = load bfloat, ptr %1
+  ret bfloat %2
+}
+
+define bfloat @bfloat_imm() nounwind {
+; CHECKIZFBFMIN-LABEL: bfloat_imm:
+; CHECKIZFBFMIN:       # %bb.0:
+; CHECKIZFBFMIN-NEXT:    lui a0, %hi(.LCPI7_0)
+; CHECKIZFBFMIN-NEXT:    flh fa0, %lo(.LCPI7_0)(a0)
+; CHECKIZFBFMIN-NEXT:    ret
+  ret bfloat 3.0
+}
+
+define dso_local void @bfloat_store(ptr %a, bfloat %b) nounwind {
+; CHECKIZFBFMIN-LABEL: bfloat_store:
+; CHECKIZFBFMIN:       # %bb.0:
+; CHECKIZFBFMIN-NEXT:    fsh fa0, 0(a0)
+; CHECKIZFBFMIN-NEXT:    fsh fa0, 16(a0)
+; CHECKIZFBFMIN-NEXT:    ret
+  store bfloat %b, ptr %a
+  %1 = getelementptr bfloat, ptr %a, i32 8
+  store bfloat %b, ptr %1
+  ret void
+}


        


More information about the llvm-commits mailing list