[Mlir-commits] [mlir] [mlir][arith][transforms] Fix f4E2M1FN to f32 cast (PR #160121)
Jorn Tuyls
llvmlistbot at llvm.org
Mon Sep 22 08:57:29 PDT 2025
https://github.com/jtuyls updated https://github.com/llvm/llvm-project/pull/160121
From 57ba6f887c116063961f4d525bdcff1fa36989fa Mon Sep 17 00:00:00 2001
From: Jorn Tuyls <jorn.tuyls at gmail.com>
Date: Mon, 22 Sep 2025 09:01:10 -0500
Subject: [PATCH] [mlir][arith][transforms] Fix f4E2M1FN to f32 cast
---
.../Dialect/Arith/Transforms/ExpandOps.cpp | 11 +--
.../CPU/test-arith-expand-truncf-extf.mlir | 67 +++++++++++++++++++
2 files changed, 74 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 6e7421daeb223..adeb50b6da628 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -387,12 +387,15 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter);
Value c0x2 = createConst(loc, i4Ty, 0x2, rewriter);
Value c0x4 = createConst(loc, i4Ty, 0x4, rewriter);
+ Value c0x7 = createConst(loc, i4Ty, 0x7, rewriter);
+
+ Value i4BitsNoSign = arith::AndIOp::create(b, i4Bits, c0x7);
// Set last Exponent bit and Mantissa.
Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter);
- Value bits1To24 = arith::ShLIOp::create(b, i4Bits, c0x2);
+ Value bits1To24 = arith::ShLIOp::create(b, i4BitsNoSign, c0x2);
Value isHalf =
- arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x1);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x1);
bits1To24 = arith::SelectOp::create(b, isHalf, c0x0, bits1To24);
bits1To24 = arith::ExtUIOp::create(b, i32Ty, bits1To24);
bits1To24 = arith::ShLIOp::create(b, bits1To24, c0x00000014);
@@ -402,11 +405,11 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter);
Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter);
Value useLargerExp =
- arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x4);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4BitsNoSign, c0x4);
Value bits25To31 =
arith::SelectOp::create(b, useLargerExp, highExpBits, lowExpBits);
Value zeroExp =
- arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x0);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x0);
bits25To31 = arith::SelectOp::create(b, zeroExp, zeroExpBits, bits25To31);
// Set sign.
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
index 9c310d80d4c2d..f2970618d5b6e 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
@@ -28,7 +28,18 @@ func.func @entry() {
%zero = arith.constant 0.0 : f32
%half = arith.constant 0.5 : f32
%one = arith.constant 1.0 : f32
+ %oneAndAHalf = arith.constant 1.5 : f32
+ %two = arith.constant 2.0 : f32
+ %three = arith.constant 3.0 : f32
+ %four = arith.constant 4.0 : f32
%max = arith.constant 6.0 : f32
+ %minZero = arith.constant -0.0 : f32
+ %minHalf = arith.constant -0.5 : f32
+ %minOne = arith.constant -1.0 : f32
+ %minOneAndAHalf = arith.constant -1.5 : f32
+ %minTwo = arith.constant -2.0 : f32
+ %minThree = arith.constant -3.0 : f32
+ %minFour = arith.constant -4.0 : f32
%min = arith.constant -6.0 : f32
%lowerThanMin = arith.constant -1000000.0 : f32
%higherThanMax = arith.constant 1000000.0 : f32
@@ -41,8 +52,28 @@ func.func @entry() {
func.call @check_truncf(%half) : (f32) -> ()
// CHECK: 2
func.call @check_truncf(%one) : (f32) -> ()
+ // CHECK: 3
+ func.call @check_truncf(%oneAndAHalf) : (f32) -> ()
+ // CHECK: 4
+ func.call @check_truncf(%two) : (f32) -> ()
+ // CHECK: 5
+ func.call @check_truncf(%three) : (f32) -> ()
+ // CHECK: 6
+ func.call @check_truncf(%four) : (f32) -> ()
// CHECK: 7
func.call @check_truncf(%max) : (f32) -> ()
+ // CHECK: 9
+ func.call @check_truncf(%minHalf) : (f32) -> ()
+ // CHECK: 10
+ func.call @check_truncf(%minOne) : (f32) -> ()
+ // CHECK: 11
+ func.call @check_truncf(%minOneAndAHalf) : (f32) -> ()
+ // CHECK: 12
+ func.call @check_truncf(%minTwo) : (f32) -> ()
+ // CHECK: 13
+ func.call @check_truncf(%minThree) : (f32) -> ()
+ // CHECK: 14
+ func.call @check_truncf(%minFour) : (f32) -> ()
// CHECK: 15
func.call @check_truncf(%min) : (f32) -> ()
// CHECK: 7
@@ -60,9 +91,45 @@ func.func @entry() {
// CHECK: 0.5
%halfF4 = arith.truncf %half : f32 to f4E2M1FN
func.call @check_extf(%halfF4) : (f4E2M1FN) -> ()
+ // CHECK: 1
+ %oneF4 = arith.truncf %one : f32 to f4E2M1FN
+ func.call @check_extf(%oneF4) : (f4E2M1FN) -> ()
+ // CHECK: 1.5
+ %oneAndAHalfF4 = arith.truncf %oneAndAHalf : f32 to f4E2M1FN
+ func.call @check_extf(%oneAndAHalfF4) : (f4E2M1FN) -> ()
+ // CHECK: 2
+ %twoF4 = arith.truncf %two : f32 to f4E2M1FN
+ func.call @check_extf(%twoF4) : (f4E2M1FN) -> ()
+ // CHECK: 3
+ %threeF4 = arith.truncf %three : f32 to f4E2M1FN
+ func.call @check_extf(%threeF4) : (f4E2M1FN) -> ()
+ // CHECK: 4
+ %fourF4 = arith.truncf %four : f32 to f4E2M1FN
+ func.call @check_extf(%fourF4) : (f4E2M1FN) -> ()
// CHECK: 6
%higherThanMaxF4 = arith.truncf %higherThanMax : f32 to f4E2M1FN
func.call @check_extf(%higherThanMaxF4) : (f4E2M1FN) -> ()
+ // CHECK: -0
+ %minZeroF4 = arith.truncf %minZero : f32 to f4E2M1FN
+ func.call @check_extf(%minZeroF4) : (f4E2M1FN) -> ()
+ // CHECK: -0.5
+ %minHalfF4 = arith.truncf %minHalf : f32 to f4E2M1FN
+ func.call @check_extf(%minHalfF4) : (f4E2M1FN) -> ()
+ // CHECK: -1
+ %minOneF4 = arith.truncf %minOne : f32 to f4E2M1FN
+ func.call @check_extf(%minOneF4) : (f4E2M1FN) -> ()
+ // CHECK: -1.5
+ %minOneAndAHalfF4 = arith.truncf %minOneAndAHalf : f32 to f4E2M1FN
+ func.call @check_extf(%minOneAndAHalfF4) : (f4E2M1FN) -> ()
+ // CHECK: -2
+ %minTwoF4 = arith.truncf %minTwo : f32 to f4E2M1FN
+ func.call @check_extf(%minTwoF4) : (f4E2M1FN) -> ()
+ // CHECK: -3
+ %minThreeF4 = arith.truncf %minThree : f32 to f4E2M1FN
+ func.call @check_extf(%minThreeF4) : (f4E2M1FN) -> ()
+ // CHECK: -4
+ %minFourF4 = arith.truncf %minFour : f32 to f4E2M1FN
+ func.call @check_extf(%minFourF4) : (f4E2M1FN) -> ()
// CHECK: -6
%lowerThanMinF4 = arith.truncf %lowerThanMin : f32 to f4E2M1FN
func.call @check_extf(%lowerThanMinF4) : (f4E2M1FN) -> ()
More information about the Mlir-commits
mailing list