[llvm] [InstCombine] Only fold bitcast(fptrunc) if destination type matches fptrunc result type. (PR #77046)

Victor Mustya via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 4 20:34:06 PST 2024


https://github.com/vmustya created https://github.com/llvm/llvm-project/pull/77046

It's not enough to check that the destination type is floating point,
because the following chain may be incorrectly optimized:
```LLVM
  %trunc = fptrunc float %src to bfloat
  %cast = bitcast bfloat %trunc to half
```
Before the fix, the instruction sequence above was folded into a
single fptrunc instruction:
```LLVM
  %trunc = fptrunc float %src to half
```

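This transformation is semantically incorrect: bfloat and half have
different bit layouts, so reinterpreting the bfloat bits as half does not
yield the value that a direct fptrunc to half would produce.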
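
To make the difference concrete, here is a minimal C++ sketch of what the two IR sequences compute for `%src = 1.0`. The helpers `floatToBFloatBits` and `halfBitsToFloat` are hypothetical names introduced only for this illustration (they are not LLVM APIs), and the bfloat conversion assumes round-to-nearest-even with NaN handling omitted:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical helper: models `fptrunc float to bfloat` by returning the
// bfloat16 bit pattern of f, rounded to nearest-even.
// Assumption: NaN inputs are not handled in this sketch.
static uint16_t floatToBFloatBits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  // Add 0x7FFF plus the lowest kept bit so that ties round to even.
  uint32_t roundingBias = 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<uint16_t>((bits + roundingBias) >> 16);
}

// Hypothetical helper: decodes an IEEE binary16 (half) bit pattern.
// Applying it to bfloat bits models `bitcast bfloat to half`.
static float halfBitsToFloat(uint16_t h) {
  int sign = (h >> 15) & 1;
  int exp = (h >> 10) & 0x1F;   // 5 exponent bits, bias 15
  int mant = h & 0x3FF;         // 10 mantissa bits
  float value;
  if (exp == 0)
    value = std::ldexp(static_cast<float>(mant), -24);  // subnormal
  else if (exp == 31)
    value = mant ? NAN : INFINITY;
  else
    value = std::ldexp(1.0f + mant / 1024.0f, exp - 15);
  return sign ? -value : value;
}
```

With these helpers, `floatToBFloatBits(1.0f)` is `0x3F80`, and `halfBitsToFloat(0x3F80)` is `1.875`, whereas the half encoding of `1.0` is `0x3C00`. So the original `fptrunc` + `bitcast` chain evaluates to `1.875` for this input, while the folded `fptrunc float %src to half` evaluates to `1.0`.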

From 1a2b051942ca2aaaf994440d001585c9a2e08aca Mon Sep 17 00:00:00 2001
From: Victor Mustya <victor.mustya at intel.com>
Date: Thu, 4 Jan 2024 20:18:44 -0800
Subject: [PATCH] [InstCombine] Only fold bitcast(fptrunc) if destination type
 matches fptrunc result type.

It's not enough to check that the destination type is floating point,
because the following chain may be incorrectly optimized:
```LLVM
  %trunc = fptrunc float %src to bfloat
  %cast = bitcast bfloat %trunc to half
```
Before the fix, the instruction sequence above was folded into a
single fptrunc instruction:
```LLVM
  %trunc = fptrunc float %src to half
```

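This transformation is semantically incorrect: bfloat and half have
different bit layouts, so reinterpreting the bfloat bits as half does not
yield the value that a direct fptrunc to half would produce.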
---
 llvm/lib/IR/Instructions.cpp                |  4 ++--
 llvm/test/Transforms/InstCombine/fptrunc.ll | 13 +++++++++++++
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 299b4e74677dcc..87874c3abc4680 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -3203,8 +3203,8 @@ unsigned CastInst::isEliminableCastPair(
       return 0;
     case 4:
       // No-op cast in second op implies firstOp as long as the DestTy
-      // is floating point.
-      if (DstTy->isFloatingPointTy())
+      // matches MidTy.
+      if (DstTy == MidTy)
         return firstOp;
       return 0;
     case 5:
diff --git a/llvm/test/Transforms/InstCombine/fptrunc.ll b/llvm/test/Transforms/InstCombine/fptrunc.ll
index d3e153f12106e0..c78df0b83d9cdf 100644
--- a/llvm/test/Transforms/InstCombine/fptrunc.ll
+++ b/llvm/test/Transforms/InstCombine/fptrunc.ll
@@ -190,3 +190,16 @@ define half @ItoFtoF_u25_f32_f16(i25 %i) {
   %r = fptrunc float %x to half
   ret half %r
 }
+
+; Negative test - bitcast bfloat to half is not optimized
+
+define half @fptrunc_to_bfloat_bitcast_to_half(float %src) {
+; CHECK-LABEL: @fptrunc_to_bfloat_bitcast_to_half(
+; CHECK-NEXT:    [[TRUNC:%.*]] = fptrunc float [[SRC:%.*]] to bfloat
+; CHECK-NEXT:    [[CAST:%.*]] = bitcast bfloat [[TRUNC]] to half
+; CHECK-NEXT:    ret half [[CAST]]
+;
+  %trunc = fptrunc float %src to bfloat
+  %cast = bitcast bfloat %trunc to half
+  ret half %cast
+}


