[llvm] [DAGCombine] Add DAG optimisation for BF16_TO_FP (PR #69426)
Shao-Ce SUN via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 18 00:33:16 PDT 2023
https://github.com/sunshaoce created https://github.com/llvm/llvm-project/pull/69426
Before (RV64 lowering of a bfloat-to-float conversion):
```
slli a0, a0, 48
srli a0, a0, 48
slli a0, a0, 16
```
After:
```
slli a0, a0, 16
```
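The underlying fold: BF16_TO_FP only consumes the low 16 bits of its integer operand, so a preceding `and` with 0xffff adds nothing, and the new combine drops it whenever the target (queried through `shouldKeepZExtForFP16Conv()`) does not require the explicit zero-extension. Below is a minimal IR sketch of the affected pattern; the function name and llc invocation are illustrative rather than taken from the patch, but the body has the same shape as the `bfloat_to_float` test updated further down.
```
; Illustrative reproducer (assumed llc flags; mirrors the bfloat_to_float test):
;   llc -mtriple=riscv64 -mattr=+d -target-abi=lp64 input.ll
; The fpext lowers to BF16_TO_FP; without the combine the i16 operand is first
; zero-extended with slli/srli by 48 before the final slli by 16.
define float @bf16_to_float(bfloat %a) nounwind {
  %1 = fpext bfloat %a to float
  ret float %1
}
```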
From ceb92792f3bdb521f97bc35e57b253475e31321d Mon Sep 17 00:00:00 2001
From: Shao-Ce SUN <sunshaoce at gmail.com>
Date: Wed, 18 Oct 2023 15:22:17 +0800
Subject: [PATCH] [DAGCombine] Add DAG optimisation for BF16_TO_FP
---
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 19 +++++++++
llvm/test/CodeGen/RISCV/bfloat-convert.ll | 42 -------------------
llvm/test/CodeGen/RISCV/bfloat.ll | 32 ++------------
3 files changed, 23 insertions(+), 70 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 2dfdddad3cc389f..1e9d2176befd93b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -545,6 +545,7 @@ namespace {
SDValue visitFP_TO_FP16(SDNode *N);
SDValue visitFP16_TO_FP(SDNode *N);
SDValue visitFP_TO_BF16(SDNode *N);
+ SDValue visitBF16_TO_FP(SDNode *N);
SDValue visitVECREDUCE(SDNode *N);
SDValue visitVPOp(SDNode *N);
SDValue visitGET_FPENV_MEM(SDNode *N);
@@ -1912,6 +1913,7 @@ void DAGCombiner::Run(CombineLevel AtLevel) {
SDValue DAGCombiner::visit(SDNode *N) {
switch (N->getOpcode()) {
+ // clang-format off
default: break;
case ISD::TokenFactor: return visitTokenFactor(N);
case ISD::MERGE_VALUES: return visitMERGE_VALUES(N);
@@ -2043,6 +2045,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::FP_TO_FP16: return visitFP_TO_FP16(N);
case ISD::FP16_TO_FP: return visitFP16_TO_FP(N);
case ISD::FP_TO_BF16: return visitFP_TO_BF16(N);
+ case ISD::BF16_TO_FP: return visitBF16_TO_FP(N);
case ISD::FREEZE: return visitFREEZE(N);
case ISD::GET_FPENV_MEM: return visitGET_FPENV_MEM(N);
case ISD::SET_FPENV_MEM: return visitSET_FPENV_MEM(N);
@@ -2064,6 +2067,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
#define BEGIN_REGISTER_VP_SDNODE(SDOPC, ...) case ISD::SDOPC:
#include "llvm/IR/VPIntrinsics.def"
return visitVPOp(N);
+ // clang-format on
}
return SDValue();
}
@@ -26219,6 +26223,21 @@ SDValue DAGCombiner::visitFP_TO_BF16(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitBF16_TO_FP(SDNode *N) {
+ SDValue N0 = N->getOperand(0);
+
+ // fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
+ if (!TLI.shouldKeepZExtForFP16Conv() && N0->getOpcode() == ISD::AND) {
+ ConstantSDNode *AndConst = getAsNonOpaqueConstant(N0.getOperand(1));
+ if (AndConst && AndConst->getAPIntValue() == 0xffff) {
+ return DAG.getNode(ISD::BF16_TO_FP, SDLoc(N), N->getValueType(0),
+ N0.getOperand(0));
+ }
+ }
+
+ return SDValue();
+}
+
SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
SDValue N0 = N->getOperand(0);
EVT VT = N0.getValueType();
diff --git a/llvm/test/CodeGen/RISCV/bfloat-convert.ll b/llvm/test/CodeGen/RISCV/bfloat-convert.ll
index 8a0c4240d161bfb..bfa2c3bb4a8ba66 100644
--- a/llvm/test/CodeGen/RISCV/bfloat-convert.ll
+++ b/llvm/test/CodeGen/RISCV/bfloat-convert.ll
@@ -39,8 +39,6 @@ define i16 @fcvt_si_bf16(bfloat %a) nounwind {
; RV64ID-LABEL: fcvt_si_bf16:
; RV64ID: # %bb.0:
; RV64ID-NEXT: fmv.x.w a0, fa0
-; RV64ID-NEXT: slli a0, a0, 48
-; RV64ID-NEXT: srli a0, a0, 48
; RV64ID-NEXT: slli a0, a0, 16
; RV64ID-NEXT: fmv.w.x fa5, a0
; RV64ID-NEXT: fcvt.l.s a0, fa5, rtz
@@ -100,8 +98,6 @@ define i16 @fcvt_si_bf16_sat(bfloat %a) nounwind {
; RV64ID-LABEL: fcvt_si_bf16_sat:
; RV64ID: # %bb.0: # %start
; RV64ID-NEXT: fmv.x.w a0, fa0
-; RV64ID-NEXT: slli a0, a0, 48
-; RV64ID-NEXT: srli a0, a0, 48
; RV64ID-NEXT: slli a0, a0, 16
; RV64ID-NEXT: fmv.w.x fa5, a0
; RV64ID-NEXT: feq.s a0, fa5, fa5
@@ -145,8 +141,6 @@ define i16 @fcvt_ui_bf16(bfloat %a) nounwind {
; RV64ID-LABEL: fcvt_ui_bf16:
; RV64ID: # %bb.0:
; RV64ID-NEXT: fmv.x.w a0, fa0
-; RV64ID-NEXT: slli a0, a0, 48
-; RV64ID-NEXT: srli a0, a0, 48
; RV64ID-NEXT: slli a0, a0, 16
; RV64ID-NEXT: fmv.w.x fa5, a0
; RV64ID-NEXT: fcvt.lu.s a0, fa5, rtz
@@ -196,8 +190,6 @@ define i16 @fcvt_ui_bf16_sat(bfloat %a) nounwind {
; RV64ID-NEXT: lui a0, %hi(.LCPI3_0)
; RV64ID-NEXT: flw fa5, %lo(.LCPI3_0)(a0)
; RV64ID-NEXT: fmv.x.w a0, fa0
-; RV64ID-NEXT: slli a0, a0, 48
-; RV64ID-NEXT: srli a0, a0, 48
; RV64ID-NEXT: slli a0, a0, 16
; RV64ID-NEXT: fmv.w.x fa4, a0
; RV64ID-NEXT: fmv.w.x fa3, zero
@@ -235,8 +227,6 @@ define i32 @fcvt_w_bf16(bfloat %a) nounwind {
; RV64ID-LABEL: fcvt_w_bf16:
; RV64ID: # %bb.0:
; RV64ID-NEXT: fmv.x.w a0, fa0
-; RV64ID-NEXT: slli a0, a0, 48
-; RV64ID-NEXT: srli a0, a0, 48
; RV64ID-NEXT: slli a0, a0, 16
; RV64ID-NEXT: fmv.w.x fa5, a0
; RV64ID-NEXT: fcvt.l.s a0, fa5, rtz
@@ -281,8 +271,6 @@ define i32 @fcvt_w_bf16_sat(bfloat %a) nounwind {
; RV64ID-LABEL: fcvt_w_bf16_sat:
; RV64ID: # %bb.0: # %start
; RV64ID-NEXT: fmv.x.w a0, fa0
-; RV64ID-NEXT: slli a0, a0, 48
-; RV64ID-NEXT: srli a0, a0, 48
; RV64ID-NEXT: slli a0, a0, 16
; RV64ID-NEXT: fmv.w.x fa5, a0
; RV64ID-NEXT: fcvt.w.s a0, fa5, rtz
@@ -321,8 +309,6 @@ define i32 @fcvt_wu_bf16(bfloat %a) nounwind {
; RV64ID-LABEL: fcvt_wu_bf16:
; RV64ID: # %bb.0:
; RV64ID-NEXT: fmv.x.w a0, fa0
-; RV64ID-NEXT: slli a0, a0, 48
-; RV64ID-NEXT: srli a0, a0, 48
; RV64ID-NEXT: slli a0, a0, 16
; RV64ID-NEXT: fmv.w.x fa5, a0
; RV64ID-NEXT: fcvt.lu.s a0, fa5, rtz
@@ -361,8 +347,6 @@ define i32 @fcvt_wu_bf16_multiple_use(bfloat %x, ptr %y) nounwind {
; RV64ID-LABEL: fcvt_wu_bf16_multiple_use:
; RV64ID: # %bb.0:
; RV64ID-NEXT: fmv.x.w a0, fa0
-; RV64ID-NEXT: slli a0, a0, 48
-; RV64ID-NEXT: srli a0, a0, 48
; RV64ID-NEXT: slli a0, a0, 16
; RV64ID-NEXT: fmv.w.x fa5, a0
; RV64ID-NEXT: fcvt.lu.s a0, fa5, rtz
@@ -413,8 +397,6 @@ define i32 @fcvt_wu_bf16_sat(bfloat %a) nounwind {
; RV64ID-LABEL: fcvt_wu_bf16_sat:
; RV64ID: # %bb.0: # %start
; RV64ID-NEXT: fmv.x.w a0, fa0
-; RV64ID-NEXT: slli a0, a0, 48
-; RV64ID-NEXT: srli a0, a0, 48
; RV64ID-NEXT: slli a0, a0, 16
; RV64ID-NEXT: fmv.w.x fa5, a0
; RV64ID-NEXT: fcvt.wu.s a0, fa5, rtz
@@ -463,8 +445,6 @@ define i64 @fcvt_l_bf16(bfloat %a) nounwind {
; RV64ID-LABEL: fcvt_l_bf16:
; RV64ID: # %bb.0:
; RV64ID-NEXT: fmv.x.w a0, fa0
-; RV64ID-NEXT: slli a0, a0, 48
-; RV64ID-NEXT: srli a0, a0, 48
; RV64ID-NEXT: slli a0, a0, 16
; RV64ID-NEXT: fmv.w.x fa5, a0
; RV64ID-NEXT: fcvt.l.s a0, fa5, rtz
@@ -606,8 +586,6 @@ define i64 @fcvt_l_bf16_sat(bfloat %a) nounwind {
; RV64ID-LABEL: fcvt_l_bf16_sat:
; RV64ID: # %bb.0: # %start
; RV64ID-NEXT: fmv.x.w a0, fa0
-; RV64ID-NEXT: slli a0, a0, 48
-; RV64ID-NEXT: srli a0, a0, 48
; RV64ID-NEXT: slli a0, a0, 16
; RV64ID-NEXT: fmv.w.x fa5, a0
; RV64ID-NEXT: fcvt.l.s a0, fa5, rtz
@@ -654,8 +632,6 @@ define i64 @fcvt_lu_bf16(bfloat %a) nounwind {
; RV64ID-LABEL: fcvt_lu_bf16:
; RV64ID: # %bb.0:
; RV64ID-NEXT: fmv.x.w a0, fa0
-; RV64ID-NEXT: slli a0, a0, 48
-; RV64ID-NEXT: srli a0, a0, 48
; RV64ID-NEXT: slli a0, a0, 16
; RV64ID-NEXT: fmv.w.x fa5, a0
; RV64ID-NEXT: fcvt.lu.s a0, fa5, rtz
@@ -730,8 +706,6 @@ define i64 @fcvt_lu_bf16_sat(bfloat %a) nounwind {
; RV64ID-LABEL: fcvt_lu_bf16_sat:
; RV64ID: # %bb.0: # %start
; RV64ID-NEXT: fmv.x.w a0, fa0
-; RV64ID-NEXT: slli a0, a0, 48
-; RV64ID-NEXT: srli a0, a0, 48
; RV64ID-NEXT: slli a0, a0, 16
; RV64ID-NEXT: fmv.w.x fa5, a0
; RV64ID-NEXT: fcvt.lu.s a0, fa5, rtz
@@ -1200,8 +1174,6 @@ define float @fcvt_s_bf16(bfloat %a) nounwind {
; RV64ID-LABEL: fcvt_s_bf16:
; RV64ID: # %bb.0:
; RV64ID-NEXT: fmv.x.w a0, fa0
-; RV64ID-NEXT: slli a0, a0, 48
-; RV64ID-NEXT: srli a0, a0, 48
; RV64ID-NEXT: slli a0, a0, 16
; RV64ID-NEXT: fmv.w.x fa0, a0
; RV64ID-NEXT: ret
@@ -1313,8 +1285,6 @@ define double @fcvt_d_bf16(bfloat %a) nounwind {
; RV64ID-LABEL: fcvt_d_bf16:
; RV64ID: # %bb.0:
; RV64ID-NEXT: fmv.x.w a0, fa0
-; RV64ID-NEXT: slli a0, a0, 48
-; RV64ID-NEXT: srli a0, a0, 48
; RV64ID-NEXT: slli a0, a0, 16
; RV64ID-NEXT: fmv.w.x fa5, a0
; RV64ID-NEXT: fcvt.d.s fa0, fa5
@@ -1521,8 +1491,6 @@ define signext i8 @fcvt_w_s_i8(bfloat %a) nounwind {
; RV64ID-LABEL: fcvt_w_s_i8:
; RV64ID: # %bb.0:
; RV64ID-NEXT: fmv.x.w a0, fa0
-; RV64ID-NEXT: slli a0, a0, 48
-; RV64ID-NEXT: srli a0, a0, 48
; RV64ID-NEXT: slli a0, a0, 16
; RV64ID-NEXT: fmv.w.x fa5, a0
; RV64ID-NEXT: fcvt.l.s a0, fa5, rtz
@@ -1582,8 +1550,6 @@ define signext i8 @fcvt_w_s_sat_i8(bfloat %a) nounwind {
; RV64ID-LABEL: fcvt_w_s_sat_i8:
; RV64ID: # %bb.0: # %start
; RV64ID-NEXT: fmv.x.w a0, fa0
-; RV64ID-NEXT: slli a0, a0, 48
-; RV64ID-NEXT: srli a0, a0, 48
; RV64ID-NEXT: slli a0, a0, 16
; RV64ID-NEXT: fmv.w.x fa5, a0
; RV64ID-NEXT: feq.s a0, fa5, fa5
@@ -1627,8 +1593,6 @@ define zeroext i8 @fcvt_wu_s_i8(bfloat %a) nounwind {
; RV64ID-LABEL: fcvt_wu_s_i8:
; RV64ID: # %bb.0:
; RV64ID-NEXT: fmv.x.w a0, fa0
-; RV64ID-NEXT: slli a0, a0, 48
-; RV64ID-NEXT: srli a0, a0, 48
; RV64ID-NEXT: slli a0, a0, 16
; RV64ID-NEXT: fmv.w.x fa5, a0
; RV64ID-NEXT: fcvt.lu.s a0, fa5, rtz
@@ -1676,8 +1640,6 @@ define zeroext i8 @fcvt_wu_s_sat_i8(bfloat %a) nounwind {
; RV64ID-LABEL: fcvt_wu_s_sat_i8:
; RV64ID: # %bb.0: # %start
; RV64ID-NEXT: fmv.x.w a0, fa0
-; RV64ID-NEXT: slli a0, a0, 48
-; RV64ID-NEXT: srli a0, a0, 48
; RV64ID-NEXT: slli a0, a0, 16
; RV64ID-NEXT: fmv.w.x fa5, a0
; RV64ID-NEXT: fmv.w.x fa4, zero
@@ -1731,8 +1693,6 @@ define zeroext i32 @fcvt_wu_bf16_sat_zext(bfloat %a) nounwind {
; RV64ID-LABEL: fcvt_wu_bf16_sat_zext:
; RV64ID: # %bb.0: # %start
; RV64ID-NEXT: fmv.x.w a0, fa0
-; RV64ID-NEXT: slli a0, a0, 48
-; RV64ID-NEXT: srli a0, a0, 48
; RV64ID-NEXT: slli a0, a0, 16
; RV64ID-NEXT: fmv.w.x fa5, a0
; RV64ID-NEXT: fcvt.wu.s a0, fa5, rtz
@@ -1784,8 +1744,6 @@ define signext i32 @fcvt_w_bf16_sat_sext(bfloat %a) nounwind {
; RV64ID-LABEL: fcvt_w_bf16_sat_sext:
; RV64ID: # %bb.0: # %start
; RV64ID-NEXT: fmv.x.w a0, fa0
-; RV64ID-NEXT: slli a0, a0, 48
-; RV64ID-NEXT: srli a0, a0, 48
; RV64ID-NEXT: slli a0, a0, 16
; RV64ID-NEXT: fmv.w.x fa5, a0
; RV64ID-NEXT: fcvt.w.s a0, fa5, rtz
diff --git a/llvm/test/CodeGen/RISCV/bfloat.ll b/llvm/test/CodeGen/RISCV/bfloat.ll
index 5013f76f9b0b33a..d62f35388123f7c 100644
--- a/llvm/test/CodeGen/RISCV/bfloat.ll
+++ b/llvm/test/CodeGen/RISCV/bfloat.ll
@@ -164,8 +164,6 @@ define float @bfloat_to_float(bfloat %a) nounwind {
;
; RV64ID-LP64-LABEL: bfloat_to_float:
; RV64ID-LP64: # %bb.0:
-; RV64ID-LP64-NEXT: slli a0, a0, 48
-; RV64ID-LP64-NEXT: srli a0, a0, 48
; RV64ID-LP64-NEXT: slli a0, a0, 16
; RV64ID-LP64-NEXT: ret
;
@@ -179,8 +177,6 @@ define float @bfloat_to_float(bfloat %a) nounwind {
; RV64ID-LP64D-LABEL: bfloat_to_float:
; RV64ID-LP64D: # %bb.0:
; RV64ID-LP64D-NEXT: fmv.x.w a0, fa0
-; RV64ID-LP64D-NEXT: slli a0, a0, 48
-; RV64ID-LP64D-NEXT: srli a0, a0, 48
; RV64ID-LP64D-NEXT: slli a0, a0, 16
; RV64ID-LP64D-NEXT: fmv.w.x fa0, a0
; RV64ID-LP64D-NEXT: ret
@@ -223,8 +219,6 @@ define double @bfloat_to_double(bfloat %a) nounwind {
;
; RV64ID-LP64-LABEL: bfloat_to_double:
; RV64ID-LP64: # %bb.0:
-; RV64ID-LP64-NEXT: slli a0, a0, 48
-; RV64ID-LP64-NEXT: srli a0, a0, 48
; RV64ID-LP64-NEXT: slli a0, a0, 16
; RV64ID-LP64-NEXT: fmv.w.x fa5, a0
; RV64ID-LP64-NEXT: fcvt.d.s fa5, fa5
@@ -242,8 +236,6 @@ define double @bfloat_to_double(bfloat %a) nounwind {
; RV64ID-LP64D-LABEL: bfloat_to_double:
; RV64ID-LP64D: # %bb.0:
; RV64ID-LP64D-NEXT: fmv.x.w a0, fa0
-; RV64ID-LP64D-NEXT: slli a0, a0, 48
-; RV64ID-LP64D-NEXT: srli a0, a0, 48
; RV64ID-LP64D-NEXT: slli a0, a0, 16
; RV64ID-LP64D-NEXT: fmv.w.x fa5, a0
; RV64ID-LP64D-NEXT: fcvt.d.s fa0, fa5
@@ -366,10 +358,6 @@ define bfloat @bfloat_add(bfloat %a, bfloat %b) nounwind {
; RV64ID-LP64: # %bb.0:
; RV64ID-LP64-NEXT: addi sp, sp, -16
; RV64ID-LP64-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
-; RV64ID-LP64-NEXT: lui a2, 16
-; RV64ID-LP64-NEXT: addi a2, a2, -1
-; RV64ID-LP64-NEXT: and a0, a0, a2
-; RV64ID-LP64-NEXT: and a1, a1, a2
; RV64ID-LP64-NEXT: slli a1, a1, 16
; RV64ID-LP64-NEXT: fmv.w.x fa5, a1
; RV64ID-LP64-NEXT: slli a0, a0, 16
@@ -408,11 +396,7 @@ define bfloat @bfloat_add(bfloat %a, bfloat %b) nounwind {
; RV64ID-LP64D-NEXT: addi sp, sp, -16
; RV64ID-LP64D-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
; RV64ID-LP64D-NEXT: fmv.x.w a0, fa0
-; RV64ID-LP64D-NEXT: lui a1, 16
-; RV64ID-LP64D-NEXT: addi a1, a1, -1
-; RV64ID-LP64D-NEXT: and a0, a0, a1
-; RV64ID-LP64D-NEXT: fmv.x.w a2, fa1
-; RV64ID-LP64D-NEXT: and a1, a2, a1
+; RV64ID-LP64D-NEXT: fmv.x.w a1, fa1
; RV64ID-LP64D-NEXT: slli a1, a1, 16
; RV64ID-LP64D-NEXT: fmv.w.x fa5, a1
; RV64ID-LP64D-NEXT: slli a0, a0, 16
@@ -604,12 +588,8 @@ define void @bfloat_store(ptr %a, bfloat %b, bfloat %c) nounwind {
; RV64ID-LP64-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
; RV64ID-LP64-NEXT: sd s0, 0(sp) # 8-byte Folded Spill
; RV64ID-LP64-NEXT: mv s0, a0
-; RV64ID-LP64-NEXT: lui a0, 16
-; RV64ID-LP64-NEXT: addi a0, a0, -1
-; RV64ID-LP64-NEXT: and a1, a1, a0
-; RV64ID-LP64-NEXT: and a0, a2, a0
-; RV64ID-LP64-NEXT: slli a0, a0, 16
-; RV64ID-LP64-NEXT: fmv.w.x fa5, a0
+; RV64ID-LP64-NEXT: slli a2, a2, 16
+; RV64ID-LP64-NEXT: fmv.w.x fa5, a2
; RV64ID-LP64-NEXT: slli a1, a1, 16
; RV64ID-LP64-NEXT: fmv.w.x fa4, a1
; RV64ID-LP64-NEXT: fadd.s fa5, fa4, fa5
@@ -651,11 +631,7 @@ define void @bfloat_store(ptr %a, bfloat %b, bfloat %c) nounwind {
; RV64ID-LP64D-NEXT: sd s0, 0(sp) # 8-byte Folded Spill
; RV64ID-LP64D-NEXT: mv s0, a0
; RV64ID-LP64D-NEXT: fmv.x.w a0, fa0
-; RV64ID-LP64D-NEXT: lui a1, 16
-; RV64ID-LP64D-NEXT: addi a1, a1, -1
-; RV64ID-LP64D-NEXT: and a0, a0, a1
-; RV64ID-LP64D-NEXT: fmv.x.w a2, fa1
-; RV64ID-LP64D-NEXT: and a1, a2, a1
+; RV64ID-LP64D-NEXT: fmv.x.w a1, fa1
; RV64ID-LP64D-NEXT: slli a1, a1, 16
; RV64ID-LP64D-NEXT: fmv.w.x fa5, a1
; RV64ID-LP64D-NEXT: slli a0, a0, 16