[Mlir-commits] [mlir] [Arith][Transforms] Adds Truncf f32 to f4e2m1 (PR #144157)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 13 13:35:31 PDT 2025


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff HEAD~1 HEAD --extensions cpp -- mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
``````````

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 40d080d83..665069049 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -36,7 +36,7 @@ static Value createConst(Location loc, Type type, int value,
 
 /// Create an float constant.
 static Value createFloatConst(Location loc, Type type, float value,
-                         PatternRewriter &rewriter) {
+                              PatternRewriter &rewriter) {
   auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
   if (auto shapedTy = dyn_cast<ShapedType>(type)) {
     return rewriter.create<arith::ConstantOp>(
@@ -389,7 +389,7 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
 /// Table of representable values in F4E2M1:
 ///
 /// Note: x is sign bit
-/// | Binary | Value ( + / - ) 
+/// | Binary | Value ( + / - )
 /// | x000   | 0.0
 /// | x001   | 0.5
 /// | x010   | 1.0
@@ -399,12 +399,13 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
 /// | x110   | 4.0
 /// | x111   | 6.0
 ///
-/// Conversion procedure: 
+/// Conversion procedure:
 ///   Step 1: Clamp to representable bounds.
 ///   Step 2: Convert exponent by adjusting bias.
 ///   Step 3: Set mantissa to first bit.
 ///   Step 4: Special consideration for subnormal and zero exponent.
-///   Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or subnormal.
+///   Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or
+///   subnormal.
 struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(arith::TruncFOp op,
@@ -427,38 +428,48 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
 
     // Constants
     Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
-    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
-    Value cF4MantissaWidth = c0x1;         // 1
-    Value cF32FirstBitMask = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
+    Value cF32MantissaWidth =
+        createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
+    Value cF4MantissaWidth = c0x1;                      // 1
+    Value cF32FirstBitMask =
+        createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
     Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
     Value c0x00 = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
     Value c0xff = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
-    Value cF32MantissaMask = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
+    Value cF32MantissaMask =
+        createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
     Value c0x00000000 = createConst(op.getLoc(), i32Ty, 0, rewriter);
-    Value cF32Last22BitMask = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);;
-    
+    Value cF32Last22BitMask =
+        createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
+    ;
+
     // Step 1: Clamp to bounds.
     Value cHigherBound = createFloatConst(op->getLoc(), f32Ty, 6.0, rewriter);
     Value cLowerBound = createFloatConst(op->getLoc(), f32Ty, -6.0, rewriter);
-    Value clampHigh = b.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, operand, cHigherBound);
-    Value clampLow = b.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, operand, cLowerBound);
-    Value operandClamped = b.create<arith::SelectOp>(clampHigh, cHigherBound, operand);
-    operandClamped = b.create<arith::SelectOp>(clampLow, cLowerBound, operandClamped);
+    Value clampHigh = b.create<arith::CmpFOp>(arith::CmpFPredicate::UGT,
+                                              operand, cHigherBound);
+    Value clampLow = b.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, operand,
+                                             cLowerBound);
+    Value operandClamped =
+        b.create<arith::SelectOp>(clampHigh, cHigherBound, operand);
+    operandClamped =
+        b.create<arith::SelectOp>(clampLow, cLowerBound, operandClamped);
     Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
-  
+
     // Step 2: Convert exponent by adjusting bias.
     Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
     Value biasAdjustment = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
-    Value biasAdjustedSignExp = b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
+    Value biasAdjustedSignExp =
+        b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
     Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
     f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
-    
+
     // Step 3: Set mantissa to first bit.
     Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
     man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
     Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
     Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
-    
+
     // Step 4: Special consideration for conversion to 0.5.
     Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
     Value isSubnormal =
@@ -466,25 +477,26 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     Value isNegOneExp =
         b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
     Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
-    Value isNonZeroMan =
-        b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
+    Value isNonZeroMan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt,
+                                                 man23Bits, c0x00000000);
     Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
     Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
     Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
-    Value isZeroExp = 
+    Value isZeroExp =
         b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
-    
-    Value subResult = b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
+
+    Value subResult =
+        b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
     subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
     f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
 
     // Step 5: Round up if necessary.
-    Value cRound = createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
+    Value cRound =
+        createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
     Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
     Value shouldRound =
         b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
-    shouldRound =
-        b.create<arith::OrIOp>(shouldRound, isSubnormal);
+    shouldRound = b.create<arith::OrIOp>(shouldRound, isSubnormal);
     Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1);
     f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/144157


More information about the Mlir-commits mailing list