[llvm] 8c4a07c - [DAGCombiner] Fold (fp_to_bf16 (bf16_to_fp op)) -> op

Benjamin Kramer via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 15 10:54:57 PDT 2022


Author: Benjamin Kramer
Date: 2022-06-15T19:54:39+02:00
New Revision: 8c4a07c61f0a381ebe273890ed1a5857acefcb2d

URL: https://github.com/llvm/llvm-project/commit/8c4a07c61f0a381ebe273890ed1a5857acefcb2d
DIFF: https://github.com/llvm/llvm-project/commit/8c4a07c61f0a381ebe273890ed1a5857acefcb2d.diff

LOG: [DAGCombiner] Fold (fp_to_bf16 (bf16_to_fp op)) -> op

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/test/CodeGen/X86/bfloat.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 7e293bad2bf3..f5ab0c2c99a6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -510,6 +510,7 @@ namespace {
     SDValue visitMSCATTER(SDNode *N);
     SDValue visitFP_TO_FP16(SDNode *N);
     SDValue visitFP16_TO_FP(SDNode *N);
+    SDValue visitFP_TO_BF16(SDNode *N);
     SDValue visitVECREDUCE(SDNode *N);
     SDValue visitVPOp(SDNode *N);
 
@@ -1746,6 +1747,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::LIFETIME_END:       return visitLIFETIME_END(N);
   case ISD::FP_TO_FP16:         return visitFP_TO_FP16(N);
   case ISD::FP16_TO_FP:         return visitFP16_TO_FP(N);
+  case ISD::FP_TO_BF16:         return visitFP_TO_BF16(N);
   case ISD::FREEZE:             return visitFREEZE(N);
   case ISD::VECREDUCE_FADD:
   case ISD::VECREDUCE_FMUL:
@@ -23072,6 +23074,16 @@ SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::visitFP_TO_BF16(SDNode *N) {
+  SDValue N0 = N->getOperand(0);
+  // fold (fp_to_bf16 (bf16_to_fp op)) -> op, but only when op already has N's
+  // result type: a replacement value must match the type of the node it replaces.
+  if (N0->getOpcode() == ISD::BF16_TO_FP &&
+      N0->getOperand(0).getValueType() == N->getValueType(0))
+    return N0->getOperand(0);
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   EVT VT = N0.getValueType();

diff  --git a/llvm/test/CodeGen/X86/bfloat.ll b/llvm/test/CodeGen/X86/bfloat.ll
index 404ad095ca75..25926857ec7d 100644
--- a/llvm/test/CodeGen/X86/bfloat.ll
+++ b/llvm/test/CodeGen/X86/bfloat.ll
@@ -100,3 +100,16 @@ define void @store_constant(ptr %pc) {
   store bfloat 1.0, ptr %pc
   ret void
 }
+
+define void @fold_ext_trunc(ptr %pa, ptr %pc) {
+; CHECK-LABEL: fold_ext_trunc:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movzwl (%rdi), %eax
+; CHECK-NEXT:    movw %ax, (%rsi)
+; CHECK-NEXT:    retq
+  %src = load bfloat, ptr %pa
+  %wide = fpext bfloat %src to float
+  %narrow = fptrunc float %wide to bfloat
+  store bfloat %narrow, ptr %pc
+  ret void
+}


        


More information about the llvm-commits mailing list