[Mlir-commits] [mlir] faf5f28 - [mlir][arith][transforms] Fix f4E2M1FN to f32 cast (#160121)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 22 09:15:30 PDT 2025


Author: Jorn Tuyls
Date: 2025-09-22T12:15:26-04:00
New Revision: faf5f28fc26430d6f0db1cdde1e9a24a1710309d

URL: https://github.com/llvm/llvm-project/commit/faf5f28fc26430d6f0db1cdde1e9a24a1710309d
DIFF: https://github.com/llvm/llvm-project/commit/faf5f28fc26430d6f0db1cdde1e9a24a1710309d.diff

LOG: [mlir][arith][transforms] Fix f4E2M1FN to f32 cast (#160121)

The signed i4 bitcast (sign bit included) was used when setting the
exponent and mantissa; the sign bit must instead be masked off before the
shift and the comparisons.

Without this fix, for example, `-0.5` (f4) is incorrectly converted to
`-3.0` (f32):

| Binary | F4E2M1 | f32[23:32] | f32 |
| --- | --- | --- | --- |
| 1001   | -0.5    | ~~1 1000 000 01~~ | ~~-3.0~~ |

**Walkthrough:**
Bits 23 and 24 are set based on the following check (shown as fixed; the
old code compared the unmasked `i4Bits`):
```
Value isHalf =
        arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x1);
```
Because `1001 (i4) != 1`, bits 23 and 24 are set to the leading two bits
of `1001 << 2` (`0100` in i4), which are `01`. The correct bits are `00`.

Bits 25 through 31 are set based on the i4 value being greater than or
equal to 4:
```
Value useLargerExp =
        arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4BitsNoSign, c0x4);
```
As the `uge` comparison is unsigned and `1001` still carries its sign bit
(9 as an unsigned i4), this is true and those bits are incorrectly set to
`1000 000` instead of `0111 111`.

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
    mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 6e7421daeb223..adeb50b6da628 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -387,12 +387,15 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
     Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter);
     Value c0x2 = createConst(loc, i4Ty, 0x2, rewriter);
     Value c0x4 = createConst(loc, i4Ty, 0x4, rewriter);
+    Value c0x7 = createConst(loc, i4Ty, 0x7, rewriter);
+
+    Value i4BitsNoSign = arith::AndIOp::create(b, i4Bits, c0x7);
 
     // Set last Exponent bit and Mantissa.
     Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter);
-    Value bits1To24 = arith::ShLIOp::create(b, i4Bits, c0x2);
+    Value bits1To24 = arith::ShLIOp::create(b, i4BitsNoSign, c0x2);
     Value isHalf =
-        arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x1);
+        arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x1);
     bits1To24 = arith::SelectOp::create(b, isHalf, c0x0, bits1To24);
     bits1To24 = arith::ExtUIOp::create(b, i32Ty, bits1To24);
     bits1To24 = arith::ShLIOp::create(b, bits1To24, c0x00000014);
@@ -402,11 +405,11 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
     Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter);
     Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter);
     Value useLargerExp =
-        arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x4);
+        arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4BitsNoSign, c0x4);
     Value bits25To31 =
         arith::SelectOp::create(b, useLargerExp, highExpBits, lowExpBits);
     Value zeroExp =
-        arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x0);
+        arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x0);
     bits25To31 = arith::SelectOp::create(b, zeroExp, zeroExpBits, bits25To31);
 
     // Set sign.

diff  --git a/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
index 9c310d80d4c2d..f2970618d5b6e 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
@@ -28,7 +28,18 @@ func.func @entry() {
   %zero = arith.constant 0.0 : f32
   %half = arith.constant 0.5 : f32
   %one = arith.constant 1.0 : f32
+  %oneAndAHalf = arith.constant 1.5 : f32
+  %two = arith.constant 2.0 : f32
+  %three = arith.constant 3.0 : f32
+  %four = arith.constant 4.0 : f32
   %max = arith.constant 6.0 : f32
+  %minZero = arith.constant -0.0 : f32
+  %minHalf = arith.constant -0.5 : f32
+  %minOne = arith.constant -1.0 : f32
+  %minOneAndAHalf = arith.constant -1.5 : f32
+  %minTwo = arith.constant -2.0 : f32
+  %minThree = arith.constant -3.0 : f32
+  %minFour = arith.constant -4.0 : f32
   %min = arith.constant -6.0 : f32
   %lowerThanMin = arith.constant -1000000.0 : f32
   %higherThanMax = arith.constant 1000000.0 : f32
@@ -41,8 +52,28 @@ func.func @entry() {
   func.call @check_truncf(%half) : (f32) -> ()
   // CHECK: 2
   func.call @check_truncf(%one) : (f32) -> ()
+  // CHECK: 3
+  func.call @check_truncf(%oneAndAHalf) : (f32) -> ()
+  // CHECK: 4
+  func.call @check_truncf(%two) : (f32) -> ()
+  // CHECK: 5
+  func.call @check_truncf(%three) : (f32) -> ()
+  // CHECK: 6
+  func.call @check_truncf(%four) : (f32) -> ()
   // CHECK: 7
   func.call @check_truncf(%max) : (f32) -> ()
+  // CHECK: 9
+  func.call @check_truncf(%minHalf) : (f32) -> ()
+  // CHECK: 10
+  func.call @check_truncf(%minOne) : (f32) -> ()
+  // CHECK: 11
+  func.call @check_truncf(%minOneAndAHalf) : (f32) -> ()
+  // CHECK: 12
+  func.call @check_truncf(%minTwo) : (f32) -> ()
+  // CHECK: 13
+  func.call @check_truncf(%minThree) : (f32) -> ()
+  // CHECK: 14
+  func.call @check_truncf(%minFour) : (f32) -> ()
   // CHECK: 15
   func.call @check_truncf(%min) : (f32) -> ()
   // CHECK: 7
@@ -60,9 +91,45 @@ func.func @entry() {
   // CHECK: 0.5
   %halfF4 = arith.truncf %half : f32 to f4E2M1FN
   func.call @check_extf(%halfF4) : (f4E2M1FN) -> ()
+  // CHECK: 1
+  %oneF4 = arith.truncf %one : f32 to f4E2M1FN
+  func.call @check_extf(%oneF4) : (f4E2M1FN) -> ()
+  // CHECK: 1.5
+  %oneAndAHalfF4 = arith.truncf %oneAndAHalf : f32 to f4E2M1FN
+  func.call @check_extf(%oneAndAHalfF4) : (f4E2M1FN) -> ()
+  // CHECK: 2
+  %twoF4 = arith.truncf %two : f32 to f4E2M1FN
+  func.call @check_extf(%twoF4) : (f4E2M1FN) -> ()
+  // CHECK: 3
+  %threeF4 = arith.truncf %three : f32 to f4E2M1FN
+  func.call @check_extf(%threeF4) : (f4E2M1FN) -> ()
+  // CHECK: 4
+  %fourF4 = arith.truncf %four : f32 to f4E2M1FN
+  func.call @check_extf(%fourF4) : (f4E2M1FN) -> ()
   // CHECK: 6
   %higherThanMaxF4 = arith.truncf %higherThanMax : f32 to f4E2M1FN
   func.call @check_extf(%higherThanMaxF4) : (f4E2M1FN) -> ()
+  // CHECK: -0
+  %minZeroF4 = arith.truncf %minZero : f32 to f4E2M1FN
+  func.call @check_extf(%minZeroF4) : (f4E2M1FN) -> ()
+  // CHECK: -0.5
+  %minHalfF4 = arith.truncf %minHalf : f32 to f4E2M1FN
+  func.call @check_extf(%minHalfF4) : (f4E2M1FN) -> ()
+  // CHECK: -1
+  %minOneF4 = arith.truncf %minOne : f32 to f4E2M1FN
+  func.call @check_extf(%minOneF4) : (f4E2M1FN) -> ()
+  // CHECK: -1.5
+  %minOneAndAHalfF4 = arith.truncf %minOneAndAHalf : f32 to f4E2M1FN
+  func.call @check_extf(%minOneAndAHalfF4) : (f4E2M1FN) -> ()
+  // CHECK: -2
+  %minTwoF4 = arith.truncf %minTwo : f32 to f4E2M1FN
+  func.call @check_extf(%minTwoF4) : (f4E2M1FN) -> ()
+  // CHECK: -3
+  %minThreeF4 = arith.truncf %minThree : f32 to f4E2M1FN
+  func.call @check_extf(%minThreeF4) : (f4E2M1FN) -> ()
+  // CHECK: -4
+  %minFourF4 = arith.truncf %minFour : f32 to f4E2M1FN
+  func.call @check_extf(%minFourF4) : (f4E2M1FN) -> ()
   // CHECK: -6
   %lowerThanMinF4 = arith.truncf %lowerThanMin : f32 to f4E2M1FN
   func.call @check_extf(%lowerThanMinF4) : (f4E2M1FN) -> ()


        


More information about the Mlir-commits mailing list