[llvm] [RISCV] Fold FP32->BF16->FP32 (PR #69687)

Shao-Ce SUN via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 20 00:53:11 PDT 2023


https://github.com/sunshaoce created https://github.com/llvm/llvm-project/pull/69687

(No description provided.)

From fd44cce158179bdc34f6a8aacb99f18246048a62 Mon Sep 17 00:00:00 2001
From: sunshaoce <sunshaoce at gmail.com>
Date: Fri, 20 Oct 2023 15:51:43 +0800
Subject: [PATCH] [RISCV] Fold FP32->BF16->FP32

---
 llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp | 12 +++-
 llvm/test/CodeGen/RISCV/bfloat-arith.ll     | 76 ++-------------------
 2 files changed, 15 insertions(+), 73 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index cda98c8848b3554..e44d7013ef2a23c 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -2127,7 +2127,7 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
     ReplaceNode(Node, Load);
     return;
   }
-  case ISD::PREFETCH:
+  case ISD::PREFETCH: {
     unsigned Locality = Node->getConstantOperandVal(3);
     if (Locality > 2)
       break;
@@ -2158,6 +2158,16 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
     }
     break;
   }
+  case RISCVISD::FP_EXTEND_BF16: {
+    SDValue V = Node->getOperand(0);
+    // fold (fp_extend_bf16 (fp_round_bf16 op)) -> op
+    if (V.getOpcode() == RISCVISD::FP_ROUND_BF16) {
+      ReplaceUses(Node, V->getOperand(0).getNode());
+      return;
+    }
+    break;
+  }
+  }
 
   // Select the default instruction.
   SelectCode(Node);
diff --git a/llvm/test/CodeGen/RISCV/bfloat-arith.ll b/llvm/test/CodeGen/RISCV/bfloat-arith.ll
index 98c58ab2ff693ef..a3c759ea7d7ea52 100644
--- a/llvm/test/CodeGen/RISCV/bfloat-arith.ll
+++ b/llvm/test/CodeGen/RISCV/bfloat-arith.ll
@@ -109,11 +109,7 @@ define i32 @fneg_s(bfloat %a, bfloat %b) nounwind {
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    fcvt.s.bf16 fa5, fa0
 ; CHECK-NEXT:    fadd.s fa5, fa5, fa5
-; CHECK-NEXT:    fcvt.bf16.s fa5, fa5
-; CHECK-NEXT:    fcvt.s.bf16 fa5, fa5
 ; CHECK-NEXT:    fneg.s fa4, fa5
-; CHECK-NEXT:    fcvt.bf16.s fa4, fa4
-; CHECK-NEXT:    fcvt.s.bf16 fa4, fa4
 ; CHECK-NEXT:    feq.s a0, fa5, fa4
 ; CHECK-NEXT:    ret
   %1 = fadd bfloat %a, %a
@@ -130,8 +126,6 @@ define bfloat @fsgnjn_s(bfloat %a, bfloat %b) nounwind {
 ; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa1
 ; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa4, fa0
 ; RV32IZFBFMIN-NEXT:    fadd.s fa5, fa4, fa5
-; RV32IZFBFMIN-NEXT:    fcvt.bf16.s fa5, fa5
-; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa5
 ; RV32IZFBFMIN-NEXT:    fneg.s fa5, fa5
 ; RV32IZFBFMIN-NEXT:    fcvt.bf16.s fa5, fa5
 ; RV32IZFBFMIN-NEXT:    fsh fa0, 8(sp)
@@ -152,8 +146,6 @@ define bfloat @fsgnjn_s(bfloat %a, bfloat %b) nounwind {
 ; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa1
 ; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa4, fa0
 ; RV64IZFBFMIN-NEXT:    fadd.s fa5, fa4, fa5
-; RV64IZFBFMIN-NEXT:    fcvt.bf16.s fa5, fa5
-; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa5
 ; RV64IZFBFMIN-NEXT:    fneg.s fa5, fa5
 ; RV64IZFBFMIN-NEXT:    fcvt.bf16.s fa5, fa5
 ; RV64IZFBFMIN-NEXT:    fsh fa0, 0(sp)
@@ -181,11 +173,7 @@ define bfloat @fabs_s(bfloat %a, bfloat %b) nounwind {
 ; CHECK-NEXT:    fcvt.s.bf16 fa5, fa1
 ; CHECK-NEXT:    fcvt.s.bf16 fa4, fa0
 ; CHECK-NEXT:    fadd.s fa5, fa4, fa5
-; CHECK-NEXT:    fcvt.bf16.s fa5, fa5
-; CHECK-NEXT:    fcvt.s.bf16 fa5, fa5
 ; CHECK-NEXT:    fabs.s fa4, fa5
-; CHECK-NEXT:    fcvt.bf16.s fa4, fa4
-; CHECK-NEXT:    fcvt.s.bf16 fa4, fa4
 ; CHECK-NEXT:    fadd.s fa5, fa4, fa5
 ; CHECK-NEXT:    fcvt.bf16.s fa0, fa5
 ; CHECK-NEXT:    ret
@@ -244,11 +232,7 @@ define bfloat @fmsub_s(bfloat %a, bfloat %b, bfloat %c) nounwind {
 ; CHECK-NEXT:    fcvt.s.bf16 fa5, fa2
 ; CHECK-NEXT:    fmv.w.x fa4, zero
 ; CHECK-NEXT:    fadd.s fa5, fa5, fa4
-; CHECK-NEXT:    fcvt.bf16.s fa5, fa5
-; CHECK-NEXT:    fcvt.s.bf16 fa5, fa5
 ; CHECK-NEXT:    fneg.s fa5, fa5
-; CHECK-NEXT:    fcvt.bf16.s fa5, fa5
-; CHECK-NEXT:    fcvt.s.bf16 fa5, fa5
 ; CHECK-NEXT:    fcvt.s.bf16 fa4, fa1
 ; CHECK-NEXT:    fcvt.s.bf16 fa3, fa0
 ; CHECK-NEXT:    fmadd.s fa5, fa3, fa4, fa5
@@ -266,18 +250,10 @@ define bfloat @fnmadd_s(bfloat %a, bfloat %b, bfloat %c) nounwind {
 ; CHECK-NEXT:    fcvt.s.bf16 fa5, fa0
 ; CHECK-NEXT:    fmv.w.x fa4, zero
 ; CHECK-NEXT:    fadd.s fa5, fa5, fa4
-; CHECK-NEXT:    fcvt.bf16.s fa5, fa5
 ; CHECK-NEXT:    fcvt.s.bf16 fa3, fa2
 ; CHECK-NEXT:    fadd.s fa4, fa3, fa4
-; CHECK-NEXT:    fcvt.bf16.s fa4, fa4
-; CHECK-NEXT:    fcvt.s.bf16 fa5, fa5
 ; CHECK-NEXT:    fneg.s fa5, fa5
-; CHECK-NEXT:    fcvt.bf16.s fa5, fa5
-; CHECK-NEXT:    fcvt.s.bf16 fa4, fa4
 ; CHECK-NEXT:    fneg.s fa4, fa4
-; CHECK-NEXT:    fcvt.bf16.s fa4, fa4
-; CHECK-NEXT:    fcvt.s.bf16 fa4, fa4
-; CHECK-NEXT:    fcvt.s.bf16 fa5, fa5
 ; CHECK-NEXT:    fcvt.s.bf16 fa3, fa1
 ; CHECK-NEXT:    fmadd.s fa5, fa5, fa3, fa4
 ; CHECK-NEXT:    fcvt.bf16.s fa0, fa5
@@ -296,18 +272,10 @@ define bfloat @fnmadd_s_2(bfloat %a, bfloat %b, bfloat %c) nounwind {
 ; CHECK-NEXT:    fcvt.s.bf16 fa5, fa1
 ; CHECK-NEXT:    fmv.w.x fa4, zero
 ; CHECK-NEXT:    fadd.s fa5, fa5, fa4
-; CHECK-NEXT:    fcvt.bf16.s fa5, fa5
 ; CHECK-NEXT:    fcvt.s.bf16 fa3, fa2
 ; CHECK-NEXT:    fadd.s fa4, fa3, fa4
-; CHECK-NEXT:    fcvt.bf16.s fa4, fa4
-; CHECK-NEXT:    fcvt.s.bf16 fa5, fa5
 ; CHECK-NEXT:    fneg.s fa5, fa5
-; CHECK-NEXT:    fcvt.bf16.s fa5, fa5
-; CHECK-NEXT:    fcvt.s.bf16 fa4, fa4
 ; CHECK-NEXT:    fneg.s fa4, fa4
-; CHECK-NEXT:    fcvt.bf16.s fa4, fa4
-; CHECK-NEXT:    fcvt.s.bf16 fa4, fa4
-; CHECK-NEXT:    fcvt.s.bf16 fa5, fa5
 ; CHECK-NEXT:    fcvt.s.bf16 fa3, fa0
 ; CHECK-NEXT:    fmadd.s fa5, fa3, fa5, fa4
 ; CHECK-NEXT:    fcvt.bf16.s fa0, fa5
@@ -327,8 +295,6 @@ define bfloat @fnmadd_s_3(bfloat %a, bfloat %b, bfloat %c) nounwind {
 ; CHECK-NEXT:    fcvt.s.bf16 fa4, fa1
 ; CHECK-NEXT:    fcvt.s.bf16 fa3, fa0
 ; CHECK-NEXT:    fmadd.s fa5, fa3, fa4, fa5
-; CHECK-NEXT:    fcvt.bf16.s fa5, fa5
-; CHECK-NEXT:    fcvt.s.bf16 fa5, fa5
 ; CHECK-NEXT:    fneg.s fa5, fa5
 ; CHECK-NEXT:    fcvt.bf16.s fa0, fa5
 ; CHECK-NEXT:    ret
@@ -345,8 +311,6 @@ define bfloat @fnmadd_nsz(bfloat %a, bfloat %b, bfloat %c) nounwind {
 ; CHECK-NEXT:    fcvt.s.bf16 fa4, fa1
 ; CHECK-NEXT:    fcvt.s.bf16 fa3, fa0
 ; CHECK-NEXT:    fmadd.s fa5, fa3, fa4, fa5
-; CHECK-NEXT:    fcvt.bf16.s fa5, fa5
-; CHECK-NEXT:    fcvt.s.bf16 fa5, fa5
 ; CHECK-NEXT:    fneg.s fa5, fa5
 ; CHECK-NEXT:    fcvt.bf16.s fa0, fa5
 ; CHECK-NEXT:    ret
@@ -361,11 +325,7 @@ define bfloat @fnmsub_s(bfloat %a, bfloat %b, bfloat %c) nounwind {
 ; CHECK-NEXT:    fcvt.s.bf16 fa5, fa0
 ; CHECK-NEXT:    fmv.w.x fa4, zero
 ; CHECK-NEXT:    fadd.s fa5, fa5, fa4
-; CHECK-NEXT:    fcvt.bf16.s fa5, fa5
-; CHECK-NEXT:    fcvt.s.bf16 fa5, fa5
 ; CHECK-NEXT:    fneg.s fa5, fa5
-; CHECK-NEXT:    fcvt.bf16.s fa5, fa5
-; CHECK-NEXT:    fcvt.s.bf16 fa5, fa5
 ; CHECK-NEXT:    fcvt.s.bf16 fa4, fa2
 ; CHECK-NEXT:    fcvt.s.bf16 fa3, fa1
 ; CHECK-NEXT:    fmadd.s fa5, fa5, fa3, fa4
@@ -383,11 +343,7 @@ define bfloat @fnmsub_s_2(bfloat %a, bfloat %b, bfloat %c) nounwind {
 ; CHECK-NEXT:    fcvt.s.bf16 fa5, fa1
 ; CHECK-NEXT:    fmv.w.x fa4, zero
 ; CHECK-NEXT:    fadd.s fa5, fa5, fa4
-; CHECK-NEXT:    fcvt.bf16.s fa5, fa5
-; CHECK-NEXT:    fcvt.s.bf16 fa5, fa5
 ; CHECK-NEXT:    fneg.s fa5, fa5
-; CHECK-NEXT:    fcvt.bf16.s fa5, fa5
-; CHECK-NEXT:    fcvt.s.bf16 fa5, fa5
 ; CHECK-NEXT:    fcvt.s.bf16 fa4, fa2
 ; CHECK-NEXT:    fcvt.s.bf16 fa3, fa0
 ; CHECK-NEXT:    fmadd.s fa5, fa3, fa5, fa4
@@ -404,11 +360,8 @@ define bfloat @fmadd_s_contract(bfloat %a, bfloat %b, bfloat %c) nounwind {
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    fcvt.s.bf16 fa5, fa1
 ; CHECK-NEXT:    fcvt.s.bf16 fa4, fa0
-; CHECK-NEXT:    fmul.s fa5, fa4, fa5
-; CHECK-NEXT:    fcvt.bf16.s fa5, fa5
-; CHECK-NEXT:    fcvt.s.bf16 fa5, fa5
-; CHECK-NEXT:    fcvt.s.bf16 fa4, fa2
-; CHECK-NEXT:    fadd.s fa5, fa5, fa4
+; CHECK-NEXT:    fcvt.s.bf16 fa3, fa2
+; CHECK-NEXT:    fmadd.s fa5, fa4, fa5, fa3
 ; CHECK-NEXT:    fcvt.bf16.s fa0, fa5
 ; CHECK-NEXT:    ret
   %1 = fmul contract bfloat %a, %b
@@ -422,13 +375,9 @@ define bfloat @fmsub_s_contract(bfloat %a, bfloat %b, bfloat %c) nounwind {
 ; CHECK-NEXT:    fcvt.s.bf16 fa5, fa2
 ; CHECK-NEXT:    fmv.w.x fa4, zero
 ; CHECK-NEXT:    fadd.s fa5, fa5, fa4
-; CHECK-NEXT:    fcvt.bf16.s fa5, fa5
 ; CHECK-NEXT:    fcvt.s.bf16 fa4, fa1
 ; CHECK-NEXT:    fcvt.s.bf16 fa3, fa0
 ; CHECK-NEXT:    fmul.s fa4, fa3, fa4
-; CHECK-NEXT:    fcvt.bf16.s fa4, fa4
-; CHECK-NEXT:    fcvt.s.bf16 fa5, fa5
-; CHECK-NEXT:    fcvt.s.bf16 fa4, fa4
 ; CHECK-NEXT:    fsub.s fa5, fa4, fa5
 ; CHECK-NEXT:    fcvt.bf16.s fa0, fa5
 ; CHECK-NEXT:    ret
@@ -444,22 +393,12 @@ define bfloat @fnmadd_s_contract(bfloat %a, bfloat %b, bfloat %c) nounwind {
 ; CHECK-NEXT:    fcvt.s.bf16 fa5, fa0
 ; CHECK-NEXT:    fmv.w.x fa4, zero
 ; CHECK-NEXT:    fadd.s fa5, fa5, fa4
-; CHECK-NEXT:    fcvt.bf16.s fa5, fa5
 ; CHECK-NEXT:    fcvt.s.bf16 fa3, fa1
 ; CHECK-NEXT:    fadd.s fa3, fa3, fa4
-; CHECK-NEXT:    fcvt.bf16.s fa3, fa3
 ; CHECK-NEXT:    fcvt.s.bf16 fa2, fa2
 ; CHECK-NEXT:    fadd.s fa4, fa2, fa4
-; CHECK-NEXT:    fcvt.bf16.s fa4, fa4
-; CHECK-NEXT:    fcvt.s.bf16 fa3, fa3
-; CHECK-NEXT:    fcvt.s.bf16 fa5, fa5
 ; CHECK-NEXT:    fmul.s fa5, fa5, fa3
-; CHECK-NEXT:    fcvt.bf16.s fa5, fa5
-; CHECK-NEXT:    fcvt.s.bf16 fa5, fa5
 ; CHECK-NEXT:    fneg.s fa5, fa5
-; CHECK-NEXT:    fcvt.bf16.s fa5, fa5
-; CHECK-NEXT:    fcvt.s.bf16 fa5, fa5
-; CHECK-NEXT:    fcvt.s.bf16 fa4, fa4
 ; CHECK-NEXT:    fsub.s fa5, fa5, fa4
 ; CHECK-NEXT:    fcvt.bf16.s fa0, fa5
 ; CHECK-NEXT:    ret
@@ -478,17 +417,10 @@ define bfloat @fnmsub_s_contract(bfloat %a, bfloat %b, bfloat %c) nounwind {
 ; CHECK-NEXT:    fcvt.s.bf16 fa5, fa0
 ; CHECK-NEXT:    fmv.w.x fa4, zero
 ; CHECK-NEXT:    fadd.s fa5, fa5, fa4
-; CHECK-NEXT:    fcvt.bf16.s fa5, fa5
 ; CHECK-NEXT:    fcvt.s.bf16 fa3, fa1
 ; CHECK-NEXT:    fadd.s fa4, fa3, fa4
-; CHECK-NEXT:    fcvt.bf16.s fa4, fa4
-; CHECK-NEXT:    fcvt.s.bf16 fa4, fa4
-; CHECK-NEXT:    fcvt.s.bf16 fa5, fa5
-; CHECK-NEXT:    fmul.s fa5, fa5, fa4
-; CHECK-NEXT:    fcvt.bf16.s fa5, fa5
-; CHECK-NEXT:    fcvt.s.bf16 fa5, fa5
-; CHECK-NEXT:    fcvt.s.bf16 fa4, fa2
-; CHECK-NEXT:    fsub.s fa5, fa4, fa5
+; CHECK-NEXT:    fcvt.s.bf16 fa3, fa2
+; CHECK-NEXT:    fnmsub.s fa5, fa5, fa4, fa3
 ; CHECK-NEXT:    fcvt.bf16.s fa0, fa5
 ; CHECK-NEXT:    ret
   %a_ = fadd bfloat 0.0, %a ; avoid negation using xor



More information about the llvm-commits mailing list