[Mlir-commits] [mlir] [mlir][arith][transforms] Fix f4E2M1FN to f32 cast (PR #160121)

Jorn Tuyls llvmlistbot at llvm.org
Mon Sep 22 08:57:29 PDT 2025


https://github.com/jtuyls updated https://github.com/llvm/llvm-project/pull/160121

>From 57ba6f887c116063961f4d525bdcff1fa36989fa Mon Sep 17 00:00:00 2001
From: Jorn Tuyls <jorn.tuyls at gmail.com>
Date: Mon, 22 Sep 2025 09:01:10 -0500
Subject: [PATCH] [mlir][arith][transforms] Fix f4E2M1FN to f32 cast

---
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 11 +--
 .../CPU/test-arith-expand-truncf-extf.mlir    | 67 +++++++++++++++++++
 2 files changed, 74 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 6e7421daeb223..adeb50b6da628 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -387,12 +387,15 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
     Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter);
     Value c0x2 = createConst(loc, i4Ty, 0x2, rewriter);
     Value c0x4 = createConst(loc, i4Ty, 0x4, rewriter);
+    Value c0x7 = createConst(loc, i4Ty, 0x7, rewriter);
+
+    Value i4BitsNoSign = arith::AndIOp::create(b, i4Bits, c0x7);
 
     // Set last Exponent bit and Mantissa.
     Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter);
-    Value bits1To24 = arith::ShLIOp::create(b, i4Bits, c0x2);
+    Value bits1To24 = arith::ShLIOp::create(b, i4BitsNoSign, c0x2);
     Value isHalf =
-        arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x1);
+        arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x1);
     bits1To24 = arith::SelectOp::create(b, isHalf, c0x0, bits1To24);
     bits1To24 = arith::ExtUIOp::create(b, i32Ty, bits1To24);
     bits1To24 = arith::ShLIOp::create(b, bits1To24, c0x00000014);
@@ -402,11 +405,11 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
     Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter);
     Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter);
     Value useLargerExp =
-        arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x4);
+        arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4BitsNoSign, c0x4);
     Value bits25To31 =
         arith::SelectOp::create(b, useLargerExp, highExpBits, lowExpBits);
     Value zeroExp =
-        arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x0);
+        arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x0);
     bits25To31 = arith::SelectOp::create(b, zeroExp, zeroExpBits, bits25To31);
 
     // Set sign.
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
index 9c310d80d4c2d..f2970618d5b6e 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
@@ -28,7 +28,18 @@ func.func @entry() {
   %zero = arith.constant 0.0 : f32
   %half = arith.constant 0.5 : f32
   %one = arith.constant 1.0 : f32
+  %oneAndAHalf = arith.constant 1.5 : f32
+  %two = arith.constant 2.0 : f32
+  %three = arith.constant 3.0 : f32
+  %four = arith.constant 4.0 : f32
   %max = arith.constant 6.0 : f32
+  %minZero = arith.constant -0.0 : f32
+  %minHalf = arith.constant -0.5 : f32
+  %minOne = arith.constant -1.0 : f32
+  %minOneAndAHalf = arith.constant -1.5 : f32
+  %minTwo = arith.constant -2.0 : f32
+  %minThree = arith.constant -3.0 : f32
+  %minFour = arith.constant -4.0 : f32
   %min = arith.constant -6.0 : f32
   %lowerThanMin = arith.constant -1000000.0 : f32
   %higherThanMax = arith.constant 1000000.0 : f32
@@ -41,8 +52,28 @@ func.func @entry() {
   func.call @check_truncf(%half) : (f32) -> ()
   // CHECK: 2
   func.call @check_truncf(%one) : (f32) -> ()
+  // CHECK: 3
+  func.call @check_truncf(%oneAndAHalf) : (f32) -> ()
+  // CHECK: 4
+  func.call @check_truncf(%two) : (f32) -> ()
+  // CHECK: 5
+  func.call @check_truncf(%three) : (f32) -> ()
+  // CHECK: 6
+  func.call @check_truncf(%four) : (f32) -> ()
   // CHECK: 7
   func.call @check_truncf(%max) : (f32) -> ()
+  // CHECK: 9
+  func.call @check_truncf(%minHalf) : (f32) -> ()
+  // CHECK: 10
+  func.call @check_truncf(%minOne) : (f32) -> ()
+  // CHECK: 11
+  func.call @check_truncf(%minOneAndAHalf) : (f32) -> ()
+  // CHECK: 12
+  func.call @check_truncf(%minTwo) : (f32) -> ()
+  // CHECK: 13
+  func.call @check_truncf(%minThree) : (f32) -> ()
+  // CHECK: 14
+  func.call @check_truncf(%minFour) : (f32) -> ()
   // CHECK: 15
   func.call @check_truncf(%min) : (f32) -> ()
   // CHECK: 7
@@ -60,9 +91,45 @@ func.func @entry() {
   // CHECK: 0.5
   %halfF4 = arith.truncf %half : f32 to f4E2M1FN
   func.call @check_extf(%halfF4) : (f4E2M1FN) -> ()
+  // CHECK: 1
+  %oneF4 = arith.truncf %one : f32 to f4E2M1FN
+  func.call @check_extf(%oneF4) : (f4E2M1FN) -> ()
+  // CHECK: 1.5
+  %oneAndAHalfF4 = arith.truncf %oneAndAHalf : f32 to f4E2M1FN
+  func.call @check_extf(%oneAndAHalfF4) : (f4E2M1FN) -> ()
+  // CHECK: 2
+  %twoF4 = arith.truncf %two : f32 to f4E2M1FN
+  func.call @check_extf(%twoF4) : (f4E2M1FN) -> ()
+  // CHECK: 3
+  %threeF4 = arith.truncf %three : f32 to f4E2M1FN
+  func.call @check_extf(%threeF4) : (f4E2M1FN) -> ()
+  // CHECK: 4
+  %fourF4 = arith.truncf %four : f32 to f4E2M1FN
+  func.call @check_extf(%fourF4) : (f4E2M1FN) -> ()
   // CHECK: 6
   %higherThanMaxF4 = arith.truncf %higherThanMax : f32 to f4E2M1FN
   func.call @check_extf(%higherThanMaxF4) : (f4E2M1FN) -> ()
+  // CHECK: -0
+  %minZeroF4 = arith.truncf %minZero : f32 to f4E2M1FN
+  func.call @check_extf(%minZeroF4) : (f4E2M1FN) -> ()
+  // CHECK: -0.5
+  %minHalfF4 = arith.truncf %minHalf : f32 to f4E2M1FN
+  func.call @check_extf(%minHalfF4) : (f4E2M1FN) -> ()
+  // CHECK: -1
+  %minOneF4 = arith.truncf %minOne : f32 to f4E2M1FN
+  func.call @check_extf(%minOneF4) : (f4E2M1FN) -> ()
+  // CHECK: -1.5
+  %minOneAndAHalfF4 = arith.truncf %minOneAndAHalf : f32 to f4E2M1FN
+  func.call @check_extf(%minOneAndAHalfF4) : (f4E2M1FN) -> ()
+  // CHECK: -2
+  %minTwoF4 = arith.truncf %minTwo : f32 to f4E2M1FN
+  func.call @check_extf(%minTwoF4) : (f4E2M1FN) -> ()
+  // CHECK: -3
+  %minThreeF4 = arith.truncf %minThree : f32 to f4E2M1FN
+  func.call @check_extf(%minThreeF4) : (f4E2M1FN) -> ()
+  // CHECK: -4
+  %minFourF4 = arith.truncf %minFour : f32 to f4E2M1FN
+  func.call @check_extf(%minFourF4) : (f4E2M1FN) -> ()
   // CHECK: -6
   %lowerThanMinF4 = arith.truncf %lowerThanMin : f32 to f4E2M1FN
   func.call @check_extf(%lowerThanMinF4) : (f4E2M1FN) -> ()



More information about the Mlir-commits mailing list