[llvm] [InstCombine] Only fold bitcast(fptrunc) if destination type matches fptrunc result type. (PR #77046)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 4 20:34:49 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Victor Mustya (vmustya)

<details>
<summary>Changes</summary>

Checking only that the destination type is floating point is not enough,
because the following chain may be optimized incorrectly:
```LLVM
  %trunc = fptrunc float %src to bfloat
  %cast = bitcast bfloat %trunc to half
```
Before the fix, the instruction sequence above was folded into a
single fptrunc instruction:
```LLVM
  %trunc = fptrunc float %src to half
```

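This transformation is semantically incorrect: bfloat and half are both
16 bits wide but use different layouts (bfloat has 8 exponent bits and
7 mantissa bits, half has 5 exponent bits and 10 mantissa bits), so
reinterpreting the bits of a bfloat as half is not the same as
truncating the original float to half.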
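
As a concrete illustration, the following C++ sketch models what the
original two-instruction sequence computes at the bit level. The helper
names `floatToBFloatBits` and `halfBitsToFloat` are hypothetical and are
not part of LLVM. The bfloat conversion assumes round-to-nearest-even
and does not handle NaN inputs.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical helper (not an LLVM API): models `fptrunc float to bfloat`
// by keeping the upper 16 bits of the float, rounding to nearest even.
// Assumption: NaN inputs are not handled specially.
std::uint16_t floatToBFloatBits(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof bits);
  std::uint32_t roundingBias = 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<std::uint16_t>((bits + roundingBias) >> 16);
}

// Hypothetical helper (not an LLVM API): decodes a 16-bit pattern as an
// IEEE half (1 sign bit, 5 exponent bits, 10 mantissa bits). Feeding it
// bfloat bits models `bitcast bfloat to half`.
float halfBitsToFloat(std::uint16_t h) {
  int sign = (h >> 15) & 1;
  int exponent = (h >> 10) & 0x1F;
  int mantissa = h & 0x3FF;
  float value;
  if (exponent == 0) {
    value = std::ldexp(static_cast<float>(mantissa), -24);  // subnormal
  } else if (exponent == 31) {
    value = mantissa ? NAN : INFINITY;
  } else {
    value = std::ldexp(1.0f + mantissa / 1024.0f, exponent - 15);
  }
  return sign ? -value : value;
}
```

For `%src = 1.0`, `floatToBFloatBits(1.0f)` gives `0x3F80`, and
`halfBitsToFloat(0x3F80)` decodes to `1.875`. A direct
`fptrunc float 1.0 to half` yields `1.0` (bit pattern `0x3C00`), so the
folded form produces a different value.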

---
Full diff: https://github.com/llvm/llvm-project/pull/77046.diff


2 Files Affected:

- (modified) llvm/lib/IR/Instructions.cpp (+2-2) 
- (modified) llvm/test/Transforms/InstCombine/fptrunc.ll (+13) 


``````````diff
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 299b4e74677dcc..87874c3abc4680 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -3203,8 +3203,8 @@ unsigned CastInst::isEliminableCastPair(
       return 0;
     case 4:
       // No-op cast in second op implies firstOp as long as the DestTy
-      // is floating point.
-      if (DstTy->isFloatingPointTy())
+      // matches MidTy.
+      if (DstTy == MidTy)
         return firstOp;
       return 0;
     case 5:
diff --git a/llvm/test/Transforms/InstCombine/fptrunc.ll b/llvm/test/Transforms/InstCombine/fptrunc.ll
index d3e153f12106e0..c78df0b83d9cdf 100644
--- a/llvm/test/Transforms/InstCombine/fptrunc.ll
+++ b/llvm/test/Transforms/InstCombine/fptrunc.ll
@@ -190,3 +190,16 @@ define half @ItoFtoF_u25_f32_f16(i25 %i) {
   %r = fptrunc float %x to half
   ret half %r
 }
+
+; Negative test - bitcast bfloat to half is not optimized
+
+define half @fptrunc_to_bfloat_bitcast_to_half(float %src) {
+; CHECK-LABEL: @fptrunc_to_bfloat_bitcast_to_half(
+; CHECK-NEXT:    [[TRUNC:%.*]] = fptrunc float [[SRC:%.*]] to bfloat
+; CHECK-NEXT:    [[CAST:%.*]] = bitcast bfloat [[TRUNC]] to half
+; CHECK-NEXT:    ret half [[CAST]]
+;
+  %trunc = fptrunc float %src to bfloat
+  %cast = bitcast bfloat %trunc to half
+  ret half %cast
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/77046

