[Mlir-commits] [mlir] [mlir][arith][transforms] Fix f4E2M1FN to f32 cast (PR #160121)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 22 07:36:09 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-arith
Author: Jorn Tuyls (jtuyls)
<details>
<summary>Changes</summary>
The i4 bitcast, including its sign bit, was used when setting the exponent and mantissa. Instead, the sign bit should be masked off before the shift and the comparisons.
Without this fix, for example, the following incorrect conversion from `-0.5` (f4) to `-3.0` (f32) happens:
| Binary | F4E2M1 | f32[23:32] | f32 |
|---|---|---|---|
| 1001 | -0.5 | ~~1 1000 000 01~~ | ~~-3.0~~ |
**Walkthrough:**
Bit positions below are 1-based (bit 32 is the sign bit). Bits 23 and 24 are set based on:
```
Value isHalf =
arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x1);
```
Because `1001 (i4) != 1`, bits 23 and 24 are set to the leading two bits of `1001 << 2`. Truncated to i4, that is `0100`, so those bits are `01`.
Bits 25 through 31 are set based on whether the i4 value is greater than or equal to 4 (an unsigned comparison):
```
Value useLargerExp =
arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4BitsNoSign, c0x4);
```
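Because `uge` is an unsigned comparison, `1001` is treated as 9, which is ≥ 4. The comparison is therefore true, and those bits are incorrectly set to `1000 000` (from `0x40000000`) instead of `0111 111` (from `0x3f000000`).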
---
Full diff: https://github.com/llvm/llvm-project/pull/160121.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp (+6-3)
- (modified) mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir (+63)
``````````diff
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 6e7421daeb223..54307c9ac843b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -387,12 +387,15 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter);
Value c0x2 = createConst(loc, i4Ty, 0x2, rewriter);
Value c0x4 = createConst(loc, i4Ty, 0x4, rewriter);
+ Value c0x7 = createConst(loc, i4Ty, 0x7, rewriter);
+
+ Value i4BitsNoSign = arith::AndIOp::create(b, i4Bits, c0x7);
// Set last Exponent bit and Mantissa.
Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter);
- Value bits1To24 = arith::ShLIOp::create(b, i4Bits, c0x2);
+ Value bits1To24 = arith::ShLIOp::create(b, i4BitsNoSign, c0x2);
Value isHalf =
- arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x1);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x1);
bits1To24 = arith::SelectOp::create(b, isHalf, c0x0, bits1To24);
bits1To24 = arith::ExtUIOp::create(b, i32Ty, bits1To24);
bits1To24 = arith::ShLIOp::create(b, bits1To24, c0x00000014);
@@ -402,7 +405,7 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter);
Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter);
Value useLargerExp =
- arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x4);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4BitsNoSign, c0x4);
Value bits25To31 =
arith::SelectOp::create(b, useLargerExp, highExpBits, lowExpBits);
Value zeroExp =
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
index 9c310d80d4c2d..f58e65a04589e 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
@@ -28,7 +28,17 @@ func.func @entry() {
%zero = arith.constant 0.0 : f32
%half = arith.constant 0.5 : f32
%one = arith.constant 1.0 : f32
+ %oneAndAHalf = arith.constant 1.5 : f32
+ %two = arith.constant 2.0 : f32
+ %three = arith.constant 3.0 : f32
+ %four = arith.constant 4.0 : f32
%max = arith.constant 6.0 : f32
+ %minHalf = arith.constant -0.5 : f32
+ %minOne = arith.constant -1.0 : f32
+ %minOneAndAHalf = arith.constant -1.5 : f32
+ %minTwo = arith.constant -2.0 : f32
+ %minThree = arith.constant -3.0 : f32
+ %minFour = arith.constant -4.0 : f32
%min = arith.constant -6.0 : f32
%lowerThanMin = arith.constant -1000000.0 : f32
%higherThanMax = arith.constant 1000000.0 : f32
@@ -41,8 +51,28 @@ func.func @entry() {
func.call @check_truncf(%half) : (f32) -> ()
// CHECK: 2
func.call @check_truncf(%one) : (f32) -> ()
+ // CHECK: 3
+ func.call @check_truncf(%oneAndAHalf) : (f32) -> ()
+ // CHECK: 4
+ func.call @check_truncf(%two) : (f32) -> ()
+ // CHECK: 5
+ func.call @check_truncf(%three) : (f32) -> ()
+ // CHECK: 6
+ func.call @check_truncf(%four) : (f32) -> ()
// CHECK: 7
func.call @check_truncf(%max) : (f32) -> ()
+ // CHECK: 9
+ func.call @check_truncf(%minHalf) : (f32) -> ()
+ // CHECK: 10
+ func.call @check_truncf(%minOne) : (f32) -> ()
+ // CHECK: 11
+ func.call @check_truncf(%minOneAndAHalf) : (f32) -> ()
+ // CHECK: 12
+ func.call @check_truncf(%minTwo) : (f32) -> ()
+ // CHECK: 13
+ func.call @check_truncf(%minThree) : (f32) -> ()
+ // CHECK: 14
+ func.call @check_truncf(%minFour) : (f32) -> ()
// CHECK: 15
func.call @check_truncf(%min) : (f32) -> ()
// CHECK: 7
@@ -60,9 +90,42 @@ func.func @entry() {
// CHECK: 0.5
%halfF4 = arith.truncf %half : f32 to f4E2M1FN
func.call @check_extf(%halfF4) : (f4E2M1FN) -> ()
+ // CHECK: 1
+ %oneF4 = arith.truncf %one : f32 to f4E2M1FN
+ func.call @check_extf(%oneF4) : (f4E2M1FN) -> ()
+ // CHECK: 1.5
+ %oneAndAHalfF4 = arith.truncf %oneAndAHalf : f32 to f4E2M1FN
+ func.call @check_extf(%oneAndAHalfF4) : (f4E2M1FN) -> ()
+ // CHECK: 2
+ %twoF4 = arith.truncf %two : f32 to f4E2M1FN
+ func.call @check_extf(%twoF4) : (f4E2M1FN) -> ()
+ // CHECK: 3
+ %threeF4 = arith.truncf %three : f32 to f4E2M1FN
+ func.call @check_extf(%threeF4) : (f4E2M1FN) -> ()
+ // CHECK: 4
+ %fourF4 = arith.truncf %four : f32 to f4E2M1FN
+ func.call @check_extf(%fourF4) : (f4E2M1FN) -> ()
// CHECK: 6
%higherThanMaxF4 = arith.truncf %higherThanMax : f32 to f4E2M1FN
func.call @check_extf(%higherThanMaxF4) : (f4E2M1FN) -> ()
+ // CHECK: -0.5
+ %minHalfF4 = arith.truncf %minHalf : f32 to f4E2M1FN
+ func.call @check_extf(%minHalfF4) : (f4E2M1FN) -> ()
+ // CHECK: -1
+ %minOneF4 = arith.truncf %minOne : f32 to f4E2M1FN
+ func.call @check_extf(%minOneF4) : (f4E2M1FN) -> ()
+ // CHECK: -1.5
+ %minOneAndAHalfF4 = arith.truncf %minOneAndAHalf : f32 to f4E2M1FN
+ func.call @check_extf(%minOneAndAHalfF4) : (f4E2M1FN) -> ()
+ // CHECK: -2
+ %minTwoF4 = arith.truncf %minTwo : f32 to f4E2M1FN
+ func.call @check_extf(%minTwoF4) : (f4E2M1FN) -> ()
+ // CHECK: -3
+ %minThreeF4 = arith.truncf %minThree : f32 to f4E2M1FN
+ func.call @check_extf(%minThreeF4) : (f4E2M1FN) -> ()
+ // CHECK: -4
+ %minFourF4 = arith.truncf %minFour : f32 to f4E2M1FN
+ func.call @check_extf(%minFourF4) : (f4E2M1FN) -> ()
// CHECK: -6
%lowerThanMinF4 = arith.truncf %lowerThanMin : f32 to f4E2M1FN
func.call @check_extf(%lowerThanMinF4) : (f4E2M1FN) -> ()
``````````
</details>
https://github.com/llvm/llvm-project/pull/160121
More information about the Mlir-commits mailing list