[llvm] f375ee3 - [RISCV] Add codegen for Zfbfmin instructions
Jun Sha via llvm-commits
llvm-commits at lists.llvm.org
Sun Jul 23 19:38:24 PDT 2023
Author: Jun Sha (Joshua)
Date: 2023-07-24T10:37:58+08:00
New Revision: f375ee36c4e1c87b8ed47ef44d3613bd5756f57a
URL: https://github.com/llvm/llvm-project/commit/f375ee36c4e1c87b8ed47ef44d3613bd5756f57a
DIFF: https://github.com/llvm/llvm-project/commit/f375ee36c4e1c87b8ed47ef44d3613bd5756f57a.diff
LOG: [RISCV] Add codegen for Zfbfmin instructions
The implementation in https://reviews.llvm.org/D151313 handles the case where Zfbfmin is not available. This patch adds codegen support for the six instructions provided by the Zfbfmin extension.
Reviewed By: craig.topper
Differential Revision: https://reviews.llvm.org/D153234
Added:
llvm/test/CodeGen/RISCV/zfbfmin.ll
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/lib/Target/RISCV/RISCVISelLowering.h
llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td
llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 67a8ac5b6ee767..f642124e072ccf 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -116,6 +116,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
if (Subtarget.hasStdExtZfhOrZfhmin())
addRegisterClass(MVT::f16, &RISCV::FPR16RegClass);
+ if (Subtarget.hasStdExtZfbfmin())
+ addRegisterClass(MVT::bf16, &RISCV::FPR16RegClass);
if (Subtarget.hasStdExtF())
addRegisterClass(MVT::f32, &RISCV::FPR32RegClass);
if (Subtarget.hasStdExtD())
@@ -359,6 +361,15 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin())
setOperationAction(ISD::BITCAST, MVT::i16, Custom);
+
+ if (Subtarget.hasStdExtZfbfmin()) {
+ setOperationAction(ISD::BITCAST, MVT::i16, Custom);
+ setOperationAction(ISD::BITCAST, MVT::bf16, Custom);
+ setOperationAction(ISD::FP_ROUND, MVT::bf16, Custom);
+ setOperationAction(ISD::FP_EXTEND, MVT::f32, Custom);
+ setOperationAction(ISD::FP_EXTEND, MVT::f64, Custom);
+ setOperationAction(ISD::ConstantFP, MVT::bf16, Expand);
+ }
if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) {
if (Subtarget.hasStdExtZfhOrZhinx()) {
@@ -4768,6 +4779,12 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
SDValue FPConv = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::f16, NewOp0);
return FPConv;
}
+ if (VT == MVT::bf16 && Op0VT == MVT::i16 &&
+ Subtarget.hasStdExtZfbfmin()) {
+ SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Op0);
+ SDValue FPConv = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::bf16, NewOp0);
+ return FPConv;
+ }
if (VT == MVT::f32 && Op0VT == MVT::i32 && Subtarget.is64Bit() &&
Subtarget.hasStdExtFOrZfinx()) {
SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op0);
@@ -4931,11 +4948,42 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
}
return SDValue();
}
- case ISD::FP_EXTEND:
- case ISD::FP_ROUND:
+ case ISD::FP_EXTEND: {
+ SDLoc DL(Op);
+ EVT VT = Op.getValueType();
+ SDValue Op0 = Op.getOperand(0);
+ EVT Op0VT = Op0.getValueType();
+ if (VT == MVT::f32 && Op0VT == MVT::bf16 && Subtarget.hasStdExtZfbfmin())
+ return DAG.getNode(RISCVISD::FP_EXTEND_BF16, DL, MVT::f32, Op0);
+ if (VT == MVT::f64 && Op0VT == MVT::bf16 && Subtarget.hasStdExtZfbfmin()) {
+ SDValue FloatVal =
+ DAG.getNode(RISCVISD::FP_EXTEND_BF16, DL, MVT::f32, Op0);
+ return DAG.getNode(ISD::FP_EXTEND, DL, MVT::f64, FloatVal);
+ }
+
+ if (!Op.getValueType().isVector())
+ return Op;
+ return lowerVectorFPExtendOrRoundLike(Op, DAG);
+ }
+ case ISD::FP_ROUND: {
+ SDLoc DL(Op);
+ EVT VT = Op.getValueType();
+ SDValue Op0 = Op.getOperand(0);
+ EVT Op0VT = Op0.getValueType();
+ if (VT == MVT::bf16 && Op0VT == MVT::f32 && Subtarget.hasStdExtZfbfmin())
+ return DAG.getNode(RISCVISD::FP_ROUND_BF16, DL, MVT::bf16, Op0);
+ if (VT == MVT::bf16 && Op0VT == MVT::f64 && Subtarget.hasStdExtZfbfmin() &&
+ Subtarget.hasStdExtDOrZdinx()) {
+ SDValue FloatVal =
+ DAG.getNode(ISD::FP_ROUND, DL, MVT::f32, Op0,
+ DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
+ return DAG.getNode(RISCVISD::FP_ROUND_BF16, DL, MVT::bf16, FloatVal);
+ }
+
if (!Op.getValueType().isVector())
return Op;
return lowerVectorFPExtendOrRoundLike(Op, DAG);
+ }
case ISD::STRICT_FP_ROUND:
case ISD::STRICT_FP_EXTEND:
return lowerStrictFPExtendOrRoundLike(Op, DAG);
@@ -9926,6 +9974,10 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) {
SDValue FPConv = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Op0);
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FPConv));
+ } else if (VT == MVT::i16 && Op0VT == MVT::bf16 &&
+ Subtarget.hasStdExtZfbfmin()) {
+ SDValue FPConv = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Op0);
+ Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FPConv));
} else if (VT == MVT::i32 && Op0VT == MVT::f32 && Subtarget.is64Bit() &&
Subtarget.hasStdExtFOrZfinx()) {
SDValue FPConv =
@@ -14867,7 +14919,8 @@ bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
// similar local variables rather than directly checking against the target
// ABI.
- if (UseGPRForF16_F32 && (ValVT == MVT::f16 || ValVT == MVT::f32)) {
+ if (UseGPRForF16_F32 &&
+ (ValVT == MVT::f16 || ValVT == MVT::bf16 || ValVT == MVT::f32)) {
LocVT = XLenVT;
LocInfo = CCValAssign::BCvt;
} else if (UseGPRForF64 && XLen == 64 && ValVT == MVT::f64) {
@@ -14960,7 +15013,7 @@ bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
unsigned StoreSizeBytes = XLen / 8;
Align StackAlign = Align(XLen / 8);
- if (ValVT == MVT::f16 && !UseGPRForF16_F32)
+ if ((ValVT == MVT::f16 || ValVT == MVT::bf16) && !UseGPRForF16_F32)
Reg = State.AllocateReg(ArgFPR16s);
else if (ValVT == MVT::f32 && !UseGPRForF16_F32)
Reg = State.AllocateReg(ArgFPR32s);
@@ -15117,8 +15170,9 @@ static SDValue convertLocVTToValVT(SelectionDAG &DAG, SDValue Val,
Val = convertFromScalableVector(VA.getValVT(), Val, DAG, Subtarget);
break;
case CCValAssign::BCvt:
- if (VA.getLocVT().isInteger() && VA.getValVT() == MVT::f16)
- Val = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::f16, Val);
+ if (VA.getLocVT().isInteger() &&
+ (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16))
+ Val = DAG.getNode(RISCVISD::FMV_H_X, DL, VA.getValVT(), Val);
else if (VA.getLocVT() == MVT::i64 && VA.getValVT() == MVT::f32)
Val = DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, Val);
else
@@ -15176,7 +15230,8 @@ static SDValue convertValVTToLocVT(SelectionDAG &DAG, SDValue Val,
Val = convertToScalableVector(LocVT, Val, DAG, Subtarget);
break;
case CCValAssign::BCvt:
- if (VA.getLocVT().isInteger() && VA.getValVT() == MVT::f16)
+ if (VA.getLocVT().isInteger() &&
+ (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16))
Val = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, VA.getLocVT(), Val);
else if (VA.getLocVT() == MVT::i64 && VA.getValVT() == MVT::f32)
Val = DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Val);
@@ -16196,6 +16251,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(FCVT_WU_RV64)
NODE_NAME_CASE(STRICT_FCVT_W_RV64)
NODE_NAME_CASE(STRICT_FCVT_WU_RV64)
+ NODE_NAME_CASE(FP_ROUND_BF16)
+ NODE_NAME_CASE(FP_EXTEND_BF16)
NODE_NAME_CASE(FROUND)
NODE_NAME_CASE(FPCLASS)
NODE_NAME_CASE(READ_CYCLE_WIDE)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index a6c7100ddf42b7..ec90e3c0cdcdde 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -111,6 +111,9 @@ enum NodeType : unsigned {
FCVT_W_RV64,
FCVT_WU_RV64,
+ FP_ROUND_BF16,
+ FP_EXTEND_BF16,
+
// Rounds an FP value to its corresponding integer in the same FP format.
// First operand is the value to round, the second operand is the largest
// integer that can be represented exactly in the FP format. This will be
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td
index 1f423591d3dde8..35f9f03f61a13f 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td
@@ -13,6 +13,20 @@
//
//===----------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
+// RISC-V specific DAG Nodes.
+//===----------------------------------------------------------------------===//
+
+def SDT_RISCVFP_ROUND_BF16
+ : SDTypeProfile<1, 1, [SDTCisVT<0, bf16>, SDTCisVT<1, f32>]>;
+def SDT_RISCVFP_EXTEND_BF16
+ : SDTypeProfile<1, 1, [SDTCisVT<0, f32>, SDTCisVT<1, bf16>]>;
+
+def riscv_fpround_bf16
+ : SDNode<"RISCVISD::FP_ROUND_BF16", SDT_RISCVFP_ROUND_BF16>;
+def riscv_fpextend_bf16
+ : SDNode<"RISCVISD::FP_EXTEND_BF16", SDT_RISCVFP_EXTEND_BF16>;
+
//===----------------------------------------------------------------------===//
// Instructions
//===----------------------------------------------------------------------===//
@@ -23,3 +37,27 @@ def FCVT_BF16_S : FPUnaryOp_r_frm<0b0100010, 0b01000, FPR16, FPR32, "fcvt.bf16.s
def FCVT_S_BF16 : FPUnaryOp_r_frm<0b0100000, 0b00110, FPR32, FPR16, "fcvt.s.bf16">,
Sched<[WriteFCvtF32ToF16, ReadFCvtF32ToF16]>;
} // Predicates = [HasStdExtZfbfmin]
+
+//===----------------------------------------------------------------------===//
+// Pseudo-instructions and codegen patterns
+//===----------------------------------------------------------------------===//
+
+let Predicates = [HasStdExtZfbfmin] in {
+/// Loads
+def : LdPat<load, FLH, bf16>;
+
+/// Stores
+def : StPat<store, FSH, FPR16, bf16>;
+
+/// Float conversion operations
+// f32 -> bf16, bf16 -> f32
+def : Pat<(bf16 (riscv_fpround_bf16 FPR32:$rs1)),
+ (FCVT_BF16_S FPR32:$rs1, FRM_DYN)>;
+def : Pat<(riscv_fpextend_bf16 (bf16 FPR16:$rs1)),
+ (FCVT_S_BF16 FPR16:$rs1, FRM_DYN)>;
+
+// Moves (no conversion)
+def : Pat<(bf16 (riscv_fmv_h_x GPR:$src)), (FMV_H_X GPR:$src)>;
+def : Pat<(riscv_fmv_x_anyexth (bf16 FPR16:$src)), (FMV_X_H FPR16:$src)>;
+def : Pat<(riscv_fmv_x_signexth (bf16 FPR16:$src)), (FMV_X_H FPR16:$src)>;
+} // Predicates = [HasStdExtZfbfmin]
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td
index 3ea338d9ed20dd..5dc02e5fa9f9e9 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td
@@ -16,9 +16,9 @@
//===----------------------------------------------------------------------===//
def SDT_RISCVFMV_H_X
- : SDTypeProfile<1, 1, [SDTCisVT<0, f16>, SDTCisVT<1, XLenVT>]>;
+ : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVT<1, XLenVT>]>;
def SDT_RISCVFMV_X_EXTH
- : SDTypeProfile<1, 1, [SDTCisVT<0, XLenVT>, SDTCisVT<1, f16>]>;
+ : SDTypeProfile<1, 1, [SDTCisVT<0, XLenVT>, SDTCisFP<1>]>;
def riscv_fmv_h_x
: SDNode<"RISCVISD::FMV_H_X", SDT_RISCVFMV_H_X>;
@@ -438,7 +438,7 @@ def : Pat<(f16 (any_fpround FPR32:$rs1)), (FCVT_H_S FPR32:$rs1, FRM_DYN)>;
def : Pat<(any_fpextend (f16 FPR16:$rs1)), (FCVT_S_H FPR16:$rs1)>;
// Moves (no conversion)
-def : Pat<(riscv_fmv_h_x GPR:$src), (FMV_H_X GPR:$src)>;
+def : Pat<(f16 (riscv_fmv_h_x GPR:$src)), (FMV_H_X GPR:$src)>;
def : Pat<(riscv_fmv_x_anyexth (f16 FPR16:$src)), (FMV_X_H FPR16:$src)>;
def : Pat<(riscv_fmv_x_signexth (f16 FPR16:$src)), (FMV_X_H FPR16:$src)>;
@@ -453,7 +453,7 @@ def : Pat<(any_fpround FPR32INX:$rs1), (FCVT_H_S_INX FPR32INX:$rs1, FRM_DYN)>;
def : Pat<(any_fpextend FPR16INX:$rs1), (FCVT_S_H_INX FPR16INX:$rs1)>;
// Moves (no conversion)
-def : Pat<(riscv_fmv_h_x GPR:$src), (COPY_TO_REGCLASS GPR:$src, GPR)>;
+def : Pat<(f16 (riscv_fmv_h_x GPR:$src)), (COPY_TO_REGCLASS GPR:$src, GPR)>;
def : Pat<(riscv_fmv_x_anyexth FPR16INX:$src), (COPY_TO_REGCLASS FPR16INX:$src, GPR)>;
def : Pat<(riscv_fmv_x_signexth FPR16INX:$src), (COPY_TO_REGCLASS FPR16INX:$src, GPR)>;
diff --git a/llvm/test/CodeGen/RISCV/zfbfmin.ll b/llvm/test/CodeGen/RISCV/zfbfmin.ll
new file mode 100644
index 00000000000000..b32e6dc0b14b5c
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/zfbfmin.ll
@@ -0,0 +1,92 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+experimental-zfbfmin -verify-machineinstrs \
+; RUN: -target-abi ilp32d < %s | FileCheck -check-prefix=CHECKIZFBFMIN %s
+; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+experimental-zfbfmin -verify-machineinstrs \
+; RUN: -target-abi lp64d < %s | FileCheck -check-prefix=CHECKIZFBFMIN %s
+
+define bfloat @bitcast_bf16_i16(i16 %a) nounwind {
+; CHECKIZFBFMIN-LABEL: bitcast_bf16_i16:
+; CHECKIZFBFMIN: # %bb.0:
+; CHECKIZFBFMIN-NEXT: fmv.h.x fa0, a0
+; CHECKIZFBFMIN-NEXT: ret
+ %1 = bitcast i16 %a to bfloat
+ ret bfloat %1
+}
+
+define i16 @bitcast_i16_bf16(bfloat %a) nounwind {
+; CHECKIZFBFMIN-LABEL: bitcast_i16_bf16:
+; CHECKIZFBFMIN: # %bb.0:
+; CHECKIZFBFMIN-NEXT: fmv.x.h a0, fa0
+; CHECKIZFBFMIN-NEXT: ret
+ %1 = bitcast bfloat %a to i16
+ ret i16 %1
+}
+
+define bfloat @fcvt_bf16_s(float %a) nounwind {
+; CHECKIZFBFMIN-LABEL: fcvt_bf16_s:
+; CHECKIZFBFMIN: # %bb.0:
+; CHECKIZFBFMIN-NEXT: fcvt.bf16.s fa0, fa0
+; CHECKIZFBFMIN-NEXT: ret
+ %1 = fptrunc float %a to bfloat
+ ret bfloat %1
+}
+
+define float @fcvt_s_bf16(bfloat %a) nounwind {
+; CHECKIZFBFMIN-LABEL: fcvt_s_bf16:
+; CHECKIZFBFMIN: # %bb.0:
+; CHECKIZFBFMIN-NEXT: fcvt.s.bf16 fa0, fa0
+; CHECKIZFBFMIN-NEXT: ret
+ %1 = fpext bfloat %a to float
+ ret float %1
+}
+
+define bfloat @fcvt_bf16_d(double %a) nounwind {
+; CHECKIZFBFMIN-LABEL: fcvt_bf16_d:
+; CHECKIZFBFMIN: # %bb.0:
+; CHECKIZFBFMIN-NEXT: fcvt.s.d fa5, fa0
+; CHECKIZFBFMIN-NEXT: fcvt.bf16.s fa0, fa5
+; CHECKIZFBFMIN-NEXT: ret
+ %1 = fptrunc double %a to bfloat
+ ret bfloat %1
+}
+
+define double @fcvt_d_bf16(bfloat %a) nounwind {
+; CHECKIZFBFMIN-LABEL: fcvt_d_bf16:
+; CHECKIZFBFMIN: # %bb.0:
+; CHECKIZFBFMIN-NEXT: fcvt.s.bf16 fa5, fa0
+; CHECKIZFBFMIN-NEXT: fcvt.d.s fa0, fa5
+; CHECKIZFBFMIN-NEXT: ret
+ %1 = fpext bfloat %a to double
+ ret double %1
+}
+
+define bfloat @bfloat_load(ptr %a) nounwind {
+; CHECKIZFBFMIN-LABEL: bfloat_load:
+; CHECKIZFBFMIN: # %bb.0:
+; CHECKIZFBFMIN-NEXT: flh fa0, 6(a0)
+; CHECKIZFBFMIN-NEXT: ret
+ %1 = getelementptr bfloat, ptr %a, i32 3
+ %2 = load bfloat, ptr %1
+ ret bfloat %2
+}
+
+define bfloat @bfloat_imm() nounwind {
+; CHECKIZFBFMIN-LABEL: bfloat_imm:
+; CHECKIZFBFMIN: # %bb.0:
+; CHECKIZFBFMIN-NEXT: lui a0, %hi(.LCPI7_0)
+; CHECKIZFBFMIN-NEXT: flh fa0, %lo(.LCPI7_0)(a0)
+; CHECKIZFBFMIN-NEXT: ret
+ ret bfloat 3.0
+}
+
+define dso_local void @bfloat_store(ptr %a, bfloat %b) nounwind {
+; CHECKIZFBFMIN-LABEL: bfloat_store:
+; CHECKIZFBFMIN: # %bb.0:
+; CHECKIZFBFMIN-NEXT: fsh fa0, 0(a0)
+; CHECKIZFBFMIN-NEXT: fsh fa0, 16(a0)
+; CHECKIZFBFMIN-NEXT: ret
+ store bfloat %b, ptr %a
+ %1 = getelementptr bfloat, ptr %a, i32 8
+ store bfloat %b, ptr %1
+ ret void
+}
More information about the llvm-commits mailing list