[llvm] [DAGCombine] Add DAG optimisation for BF16_TO_FP (PR #69426)

Shao-Ce SUN via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 26 01:16:45 PST 2023


https://github.com/sunshaoce updated https://github.com/llvm/llvm-project/pull/69426

>From 035d1a256d411e2669a1c012cc02cb2a16b9258f Mon Sep 17 00:00:00 2001
From: Shao-Ce SUN <sunshaoce at gmail.com>
Date: Wed, 18 Oct 2023 15:22:17 +0800
Subject: [PATCH 1/4] [DAGCombine] Add DAG optimisation for BF16_TO_FP

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 19 +++++++++
 llvm/test/CodeGen/RISCV/bfloat-convert.ll     | 42 -------------------
 llvm/test/CodeGen/RISCV/bfloat.ll             | 32 ++------------
 3 files changed, 23 insertions(+), 70 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index c2d5b40ee72510..fd172141b2bcb6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -546,6 +546,7 @@ namespace {
     SDValue visitFP_TO_FP16(SDNode *N);
     SDValue visitFP16_TO_FP(SDNode *N);
     SDValue visitFP_TO_BF16(SDNode *N);
+    SDValue visitBF16_TO_FP(SDNode *N);
     SDValue visitVECREDUCE(SDNode *N);
     SDValue visitVPOp(SDNode *N);
     SDValue visitGET_FPENV_MEM(SDNode *N);
@@ -1914,6 +1915,7 @@ void DAGCombiner::Run(CombineLevel AtLevel) {
 SDValue DAGCombiner::visit(SDNode *N) {
   // clang-format off
   switch (N->getOpcode()) {
+    // clang-format off
   default: break;
   case ISD::TokenFactor:        return visitTokenFactor(N);
   case ISD::MERGE_VALUES:       return visitMERGE_VALUES(N);
@@ -2047,6 +2049,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::FP_TO_FP16:         return visitFP_TO_FP16(N);
   case ISD::FP16_TO_FP:         return visitFP16_TO_FP(N);
   case ISD::FP_TO_BF16:         return visitFP_TO_BF16(N);
+  case ISD::BF16_TO_FP:         return visitBF16_TO_FP(N);
   case ISD::FREEZE:             return visitFREEZE(N);
   case ISD::GET_FPENV_MEM:      return visitGET_FPENV_MEM(N);
   case ISD::SET_FPENV_MEM:      return visitSET_FPENV_MEM(N);
@@ -2068,6 +2071,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
 #define BEGIN_REGISTER_VP_SDNODE(SDOPC, ...) case ISD::SDOPC:
 #include "llvm/IR/VPIntrinsics.def"
     return visitVPOp(N);
+    // clang-format on
   }
   // clang-format on
   return SDValue();
@@ -26161,6 +26165,21 @@ SDValue DAGCombiner::visitFP_TO_BF16(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::visitBF16_TO_FP(SDNode *N) {
+  SDValue N0 = N->getOperand(0);
+
+  // fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
+  if (!TLI.shouldKeepZExtForFP16Conv() && N0->getOpcode() == ISD::AND) {
+    ConstantSDNode *AndConst = getAsNonOpaqueConstant(N0.getOperand(1));
+    if (AndConst && AndConst->getAPIntValue() == 0xffff) {
+      return DAG.getNode(ISD::BF16_TO_FP, SDLoc(N), N->getValueType(0),
+                         N0.getOperand(0));
+    }
+  }
+
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   EVT VT = N0.getValueType();
diff --git a/llvm/test/CodeGen/RISCV/bfloat-convert.ll b/llvm/test/CodeGen/RISCV/bfloat-convert.ll
index 8a0c4240d161bf..bfa2c3bb4a8ba6 100644
--- a/llvm/test/CodeGen/RISCV/bfloat-convert.ll
+++ b/llvm/test/CodeGen/RISCV/bfloat-convert.ll
@@ -39,8 +39,6 @@ define i16 @fcvt_si_bf16(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_si_bf16:
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.l.s a0, fa5, rtz
@@ -100,8 +98,6 @@ define i16 @fcvt_si_bf16_sat(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_si_bf16_sat:
 ; RV64ID:       # %bb.0: # %start
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    feq.s a0, fa5, fa5
@@ -145,8 +141,6 @@ define i16 @fcvt_ui_bf16(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_ui_bf16:
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.lu.s a0, fa5, rtz
@@ -196,8 +190,6 @@ define i16 @fcvt_ui_bf16_sat(bfloat %a) nounwind {
 ; RV64ID-NEXT:    lui a0, %hi(.LCPI3_0)
 ; RV64ID-NEXT:    flw fa5, %lo(.LCPI3_0)(a0)
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa4, a0
 ; RV64ID-NEXT:    fmv.w.x fa3, zero
@@ -235,8 +227,6 @@ define i32 @fcvt_w_bf16(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_w_bf16:
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.l.s a0, fa5, rtz
@@ -281,8 +271,6 @@ define i32 @fcvt_w_bf16_sat(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_w_bf16_sat:
 ; RV64ID:       # %bb.0: # %start
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.w.s a0, fa5, rtz
@@ -321,8 +309,6 @@ define i32 @fcvt_wu_bf16(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_wu_bf16:
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.lu.s a0, fa5, rtz
@@ -361,8 +347,6 @@ define i32 @fcvt_wu_bf16_multiple_use(bfloat %x, ptr %y) nounwind {
 ; RV64ID-LABEL: fcvt_wu_bf16_multiple_use:
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.lu.s a0, fa5, rtz
@@ -413,8 +397,6 @@ define i32 @fcvt_wu_bf16_sat(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_wu_bf16_sat:
 ; RV64ID:       # %bb.0: # %start
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.wu.s a0, fa5, rtz
@@ -463,8 +445,6 @@ define i64 @fcvt_l_bf16(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_l_bf16:
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.l.s a0, fa5, rtz
@@ -606,8 +586,6 @@ define i64 @fcvt_l_bf16_sat(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_l_bf16_sat:
 ; RV64ID:       # %bb.0: # %start
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.l.s a0, fa5, rtz
@@ -654,8 +632,6 @@ define i64 @fcvt_lu_bf16(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_lu_bf16:
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.lu.s a0, fa5, rtz
@@ -730,8 +706,6 @@ define i64 @fcvt_lu_bf16_sat(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_lu_bf16_sat:
 ; RV64ID:       # %bb.0: # %start
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.lu.s a0, fa5, rtz
@@ -1200,8 +1174,6 @@ define float @fcvt_s_bf16(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_s_bf16:
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa0, a0
 ; RV64ID-NEXT:    ret
@@ -1313,8 +1285,6 @@ define double @fcvt_d_bf16(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_d_bf16:
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.d.s fa0, fa5
@@ -1521,8 +1491,6 @@ define signext i8 @fcvt_w_s_i8(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_w_s_i8:
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.l.s a0, fa5, rtz
@@ -1582,8 +1550,6 @@ define signext i8 @fcvt_w_s_sat_i8(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_w_s_sat_i8:
 ; RV64ID:       # %bb.0: # %start
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    feq.s a0, fa5, fa5
@@ -1627,8 +1593,6 @@ define zeroext i8 @fcvt_wu_s_i8(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_wu_s_i8:
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.lu.s a0, fa5, rtz
@@ -1676,8 +1640,6 @@ define zeroext i8 @fcvt_wu_s_sat_i8(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_wu_s_sat_i8:
 ; RV64ID:       # %bb.0: # %start
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fmv.w.x fa4, zero
@@ -1731,8 +1693,6 @@ define zeroext i32 @fcvt_wu_bf16_sat_zext(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_wu_bf16_sat_zext:
 ; RV64ID:       # %bb.0: # %start
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.wu.s a0, fa5, rtz
@@ -1784,8 +1744,6 @@ define signext i32 @fcvt_w_bf16_sat_sext(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_w_bf16_sat_sext:
 ; RV64ID:       # %bb.0: # %start
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.w.s a0, fa5, rtz
diff --git a/llvm/test/CodeGen/RISCV/bfloat.ll b/llvm/test/CodeGen/RISCV/bfloat.ll
index 5013f76f9b0b33..d62f35388123f7 100644
--- a/llvm/test/CodeGen/RISCV/bfloat.ll
+++ b/llvm/test/CodeGen/RISCV/bfloat.ll
@@ -164,8 +164,6 @@ define float @bfloat_to_float(bfloat %a) nounwind {
 ;
 ; RV64ID-LP64-LABEL: bfloat_to_float:
 ; RV64ID-LP64:       # %bb.0:
-; RV64ID-LP64-NEXT:    slli a0, a0, 48
-; RV64ID-LP64-NEXT:    srli a0, a0, 48
 ; RV64ID-LP64-NEXT:    slli a0, a0, 16
 ; RV64ID-LP64-NEXT:    ret
 ;
@@ -179,8 +177,6 @@ define float @bfloat_to_float(bfloat %a) nounwind {
 ; RV64ID-LP64D-LABEL: bfloat_to_float:
 ; RV64ID-LP64D:       # %bb.0:
 ; RV64ID-LP64D-NEXT:    fmv.x.w a0, fa0
-; RV64ID-LP64D-NEXT:    slli a0, a0, 48
-; RV64ID-LP64D-NEXT:    srli a0, a0, 48
 ; RV64ID-LP64D-NEXT:    slli a0, a0, 16
 ; RV64ID-LP64D-NEXT:    fmv.w.x fa0, a0
 ; RV64ID-LP64D-NEXT:    ret
@@ -223,8 +219,6 @@ define double @bfloat_to_double(bfloat %a) nounwind {
 ;
 ; RV64ID-LP64-LABEL: bfloat_to_double:
 ; RV64ID-LP64:       # %bb.0:
-; RV64ID-LP64-NEXT:    slli a0, a0, 48
-; RV64ID-LP64-NEXT:    srli a0, a0, 48
 ; RV64ID-LP64-NEXT:    slli a0, a0, 16
 ; RV64ID-LP64-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-LP64-NEXT:    fcvt.d.s fa5, fa5
@@ -242,8 +236,6 @@ define double @bfloat_to_double(bfloat %a) nounwind {
 ; RV64ID-LP64D-LABEL: bfloat_to_double:
 ; RV64ID-LP64D:       # %bb.0:
 ; RV64ID-LP64D-NEXT:    fmv.x.w a0, fa0
-; RV64ID-LP64D-NEXT:    slli a0, a0, 48
-; RV64ID-LP64D-NEXT:    srli a0, a0, 48
 ; RV64ID-LP64D-NEXT:    slli a0, a0, 16
 ; RV64ID-LP64D-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-LP64D-NEXT:    fcvt.d.s fa0, fa5
@@ -366,10 +358,6 @@ define bfloat @bfloat_add(bfloat %a, bfloat %b) nounwind {
 ; RV64ID-LP64:       # %bb.0:
 ; RV64ID-LP64-NEXT:    addi sp, sp, -16
 ; RV64ID-LP64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
-; RV64ID-LP64-NEXT:    lui a2, 16
-; RV64ID-LP64-NEXT:    addi a2, a2, -1
-; RV64ID-LP64-NEXT:    and a0, a0, a2
-; RV64ID-LP64-NEXT:    and a1, a1, a2
 ; RV64ID-LP64-NEXT:    slli a1, a1, 16
 ; RV64ID-LP64-NEXT:    fmv.w.x fa5, a1
 ; RV64ID-LP64-NEXT:    slli a0, a0, 16
@@ -408,11 +396,7 @@ define bfloat @bfloat_add(bfloat %a, bfloat %b) nounwind {
 ; RV64ID-LP64D-NEXT:    addi sp, sp, -16
 ; RV64ID-LP64D-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
 ; RV64ID-LP64D-NEXT:    fmv.x.w a0, fa0
-; RV64ID-LP64D-NEXT:    lui a1, 16
-; RV64ID-LP64D-NEXT:    addi a1, a1, -1
-; RV64ID-LP64D-NEXT:    and a0, a0, a1
-; RV64ID-LP64D-NEXT:    fmv.x.w a2, fa1
-; RV64ID-LP64D-NEXT:    and a1, a2, a1
+; RV64ID-LP64D-NEXT:    fmv.x.w a1, fa1
 ; RV64ID-LP64D-NEXT:    slli a1, a1, 16
 ; RV64ID-LP64D-NEXT:    fmv.w.x fa5, a1
 ; RV64ID-LP64D-NEXT:    slli a0, a0, 16
@@ -604,12 +588,8 @@ define void @bfloat_store(ptr %a, bfloat %b, bfloat %c) nounwind {
 ; RV64ID-LP64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
 ; RV64ID-LP64-NEXT:    sd s0, 0(sp) # 8-byte Folded Spill
 ; RV64ID-LP64-NEXT:    mv s0, a0
-; RV64ID-LP64-NEXT:    lui a0, 16
-; RV64ID-LP64-NEXT:    addi a0, a0, -1
-; RV64ID-LP64-NEXT:    and a1, a1, a0
-; RV64ID-LP64-NEXT:    and a0, a2, a0
-; RV64ID-LP64-NEXT:    slli a0, a0, 16
-; RV64ID-LP64-NEXT:    fmv.w.x fa5, a0
+; RV64ID-LP64-NEXT:    slli a2, a2, 16
+; RV64ID-LP64-NEXT:    fmv.w.x fa5, a2
 ; RV64ID-LP64-NEXT:    slli a1, a1, 16
 ; RV64ID-LP64-NEXT:    fmv.w.x fa4, a1
 ; RV64ID-LP64-NEXT:    fadd.s fa5, fa4, fa5
@@ -651,11 +631,7 @@ define void @bfloat_store(ptr %a, bfloat %b, bfloat %c) nounwind {
 ; RV64ID-LP64D-NEXT:    sd s0, 0(sp) # 8-byte Folded Spill
 ; RV64ID-LP64D-NEXT:    mv s0, a0
 ; RV64ID-LP64D-NEXT:    fmv.x.w a0, fa0
-; RV64ID-LP64D-NEXT:    lui a1, 16
-; RV64ID-LP64D-NEXT:    addi a1, a1, -1
-; RV64ID-LP64D-NEXT:    and a0, a0, a1
-; RV64ID-LP64D-NEXT:    fmv.x.w a2, fa1
-; RV64ID-LP64D-NEXT:    and a1, a2, a1
+; RV64ID-LP64D-NEXT:    fmv.x.w a1, fa1
 ; RV64ID-LP64D-NEXT:    slli a1, a1, 16
 ; RV64ID-LP64D-NEXT:    fmv.w.x fa5, a1
 ; RV64ID-LP64D-NEXT:    slli a0, a0, 16

>From 7e837df43f4ca10ec707b4b85d57d8571a21c402 Mon Sep 17 00:00:00 2001
From: Shao-Ce SUN <sunshaoce at gmail.com>
Date: Wed, 18 Oct 2023 20:05:09 +0800
Subject: [PATCH 2/4] fixup! remove visitBF16_TO_FP

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 20 ++-----------------
 1 file changed, 2 insertions(+), 18 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index fd172141b2bcb6..7ead8735896123 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -546,7 +546,6 @@ namespace {
     SDValue visitFP_TO_FP16(SDNode *N);
     SDValue visitFP16_TO_FP(SDNode *N);
     SDValue visitFP_TO_BF16(SDNode *N);
-    SDValue visitBF16_TO_FP(SDNode *N);
     SDValue visitVECREDUCE(SDNode *N);
     SDValue visitVPOp(SDNode *N);
     SDValue visitGET_FPENV_MEM(SDNode *N);
@@ -2049,7 +2048,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::FP_TO_FP16:         return visitFP_TO_FP16(N);
   case ISD::FP16_TO_FP:         return visitFP16_TO_FP(N);
   case ISD::FP_TO_BF16:         return visitFP_TO_BF16(N);
-  case ISD::BF16_TO_FP:         return visitBF16_TO_FP(N);
+  case ISD::BF16_TO_FP:         return visitFP16_TO_FP(N);
   case ISD::FREEZE:             return visitFREEZE(N);
   case ISD::GET_FPENV_MEM:      return visitGET_FPENV_MEM(N);
   case ISD::SET_FPENV_MEM:      return visitSET_FPENV_MEM(N);
@@ -26147,7 +26146,7 @@ SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
   if (!TLI.shouldKeepZExtForFP16Conv() && N0->getOpcode() == ISD::AND) {
     ConstantSDNode *AndConst = getAsNonOpaqueConstant(N0.getOperand(1));
     if (AndConst && AndConst->getAPIntValue() == 0xffff) {
-      return DAG.getNode(ISD::FP16_TO_FP, SDLoc(N), N->getValueType(0),
+      return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0),
                          N0.getOperand(0));
     }
   }
@@ -26165,21 +26164,6 @@ SDValue DAGCombiner::visitFP_TO_BF16(SDNode *N) {
   return SDValue();
 }
 
-SDValue DAGCombiner::visitBF16_TO_FP(SDNode *N) {
-  SDValue N0 = N->getOperand(0);
-
-  // fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
-  if (!TLI.shouldKeepZExtForFP16Conv() && N0->getOpcode() == ISD::AND) {
-    ConstantSDNode *AndConst = getAsNonOpaqueConstant(N0.getOperand(1));
-    if (AndConst && AndConst->getAPIntValue() == 0xffff) {
-      return DAG.getNode(ISD::BF16_TO_FP, SDLoc(N), N->getValueType(0),
-                         N0.getOperand(0));
-    }
-  }
-
-  return SDValue();
-}
-
 SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   EVT VT = N0.getValueType();

>From d4038bf65fc99a06dd9b23c617fe6f80aa901fec Mon Sep 17 00:00:00 2001
From: Shao-Ce SUN <sunshaoce at outlook.com>
Date: Thu, 16 Nov 2023 19:05:13 +0800
Subject: [PATCH 3/4] fixup! resume visitBF16_TO_FP

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 7ead8735896123..43b550491dc36f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -546,6 +546,7 @@ namespace {
     SDValue visitFP_TO_FP16(SDNode *N);
     SDValue visitFP16_TO_FP(SDNode *N);
     SDValue visitFP_TO_BF16(SDNode *N);
+    SDValue visitBF16_TO_FP(SDNode *N);
     SDValue visitVECREDUCE(SDNode *N);
     SDValue visitVPOp(SDNode *N);
     SDValue visitGET_FPENV_MEM(SDNode *N);
@@ -2048,7 +2049,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::FP_TO_FP16:         return visitFP_TO_FP16(N);
   case ISD::FP16_TO_FP:         return visitFP16_TO_FP(N);
   case ISD::FP_TO_BF16:         return visitFP_TO_BF16(N);
-  case ISD::BF16_TO_FP:         return visitFP16_TO_FP(N);
+  case ISD::BF16_TO_FP:         return visitBF16_TO_FP(N);
   case ISD::FREEZE:             return visitFREEZE(N);
   case ISD::GET_FPENV_MEM:      return visitGET_FPENV_MEM(N);
   case ISD::SET_FPENV_MEM:      return visitSET_FPENV_MEM(N);
@@ -26164,6 +26165,8 @@ SDValue DAGCombiner::visitFP_TO_BF16(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::visitBF16_TO_FP(SDNode *N) { return visitFP16_TO_FP(N); }
+
 SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   EVT VT = N0.getValueType();

>From d91839a5bb8eacb370862cdf0ee5988aab7223b7 Mon Sep 17 00:00:00 2001
From: Shao-Ce SUN <sunshaoce at outlook.com>
Date: Tue, 26 Dec 2023 17:14:07 +0800
Subject: [PATCH 4/4] fixup! add assert for visitFP16_TO_FP

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 43b550491dc36f..074b27d8c10a6c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -26141,14 +26141,16 @@ SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) {
 }
 
 SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
+  auto Op = N->getOpcode();
+  assert((Op == ISD::FP16_TO_FP || Op == ISD::BF16_TO_FP) &&
+         "opcode should be FP16_TO_FP or BF16_TO_FP.");
   SDValue N0 = N->getOperand(0);
 
   // fold fp16_to_fp(op & 0xffff) -> fp16_to_fp(op)
   if (!TLI.shouldKeepZExtForFP16Conv() && N0->getOpcode() == ISD::AND) {
     ConstantSDNode *AndConst = getAsNonOpaqueConstant(N0.getOperand(1));
     if (AndConst && AndConst->getAPIntValue() == 0xffff) {
-      return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0),
-                         N0.getOperand(0));
+      return DAG.getNode(Op, SDLoc(N), N->getValueType(0), N0.getOperand(0));
     }
   }
 



More information about the llvm-commits mailing list