[Mlir-commits] [mlir] [Arith][Transforms] Adds Truncf f32 to f4e2m1 (PR #144157)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jun 13 13:35:31 PDT 2025
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {clang-format}-->
:warning: C/C++ code formatter, clang-format found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff HEAD~1 HEAD --extensions cpp -- mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
``````````
</details>
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 40d080d83..665069049 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -36,7 +36,7 @@ static Value createConst(Location loc, Type type, int value,
/// Create an float constant.
static Value createFloatConst(Location loc, Type type, float value,
- PatternRewriter &rewriter) {
+ PatternRewriter &rewriter) {
auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
return rewriter.create<arith::ConstantOp>(
@@ -389,7 +389,7 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
/// Table of representable values in F4E2M1:
///
/// Note: x is sign bit
-/// | Binary | Value ( + / - )
+/// | Binary | Value ( + / - )
/// | x000 | 0.0
/// | x001 | 0.5
/// | x010 | 1.0
@@ -399,12 +399,13 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
/// | x110 | 4.0
/// | x111 | 6.0
///
-/// Conversion procedure:
+/// Conversion procedure:
/// Step 1: Clamp to representable bounds.
/// Step 2: Convert exponent by adjusting bias.
/// Step 3: Set mantissa to first bit.
/// Step 4: Special consideration for subnormal and zero exponent.
-/// Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or subnormal.
+/// Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or
+/// subnormal.
struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::TruncFOp op,
@@ -427,38 +428,48 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
// Constants
Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
- Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
- Value cF4MantissaWidth = c0x1; // 1
- Value cF32FirstBitMask = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
+ Value cF32MantissaWidth =
+ createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
+ Value cF4MantissaWidth = c0x1; // 1
+ Value cF32FirstBitMask =
+ createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
Value c0x00 = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
Value c0xff = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
- Value cF32MantissaMask = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
+ Value cF32MantissaMask =
+ createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
Value c0x00000000 = createConst(op.getLoc(), i32Ty, 0, rewriter);
- Value cF32Last22BitMask = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);;
-
+ Value cF32Last22BitMask =
+ createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
+ ;
+
// Step 1: Clamp to bounds.
Value cHigherBound = createFloatConst(op->getLoc(), f32Ty, 6.0, rewriter);
Value cLowerBound = createFloatConst(op->getLoc(), f32Ty, -6.0, rewriter);
- Value clampHigh = b.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, operand, cHigherBound);
- Value clampLow = b.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, operand, cLowerBound);
- Value operandClamped = b.create<arith::SelectOp>(clampHigh, cHigherBound, operand);
- operandClamped = b.create<arith::SelectOp>(clampLow, cLowerBound, operandClamped);
+ Value clampHigh = b.create<arith::CmpFOp>(arith::CmpFPredicate::UGT,
+ operand, cHigherBound);
+ Value clampLow = b.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, operand,
+ cLowerBound);
+ Value operandClamped =
+ b.create<arith::SelectOp>(clampHigh, cHigherBound, operand);
+ operandClamped =
+ b.create<arith::SelectOp>(clampLow, cLowerBound, operandClamped);
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
-
+
// Step 2: Convert exponent by adjusting bias.
Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
Value biasAdjustment = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
- Value biasAdjustedSignExp = b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
+ Value biasAdjustedSignExp =
+ b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
-
+
// Step 3: Set mantissa to first bit.
Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
-
+
// Step 4: Special consideration for conversion to 0.5.
Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
Value isSubnormal =
@@ -466,25 +477,26 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Value isNegOneExp =
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
- Value isNonZeroMan =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
+ Value isNonZeroMan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt,
+ man23Bits, c0x00000000);
Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
- Value isZeroExp =
+ Value isZeroExp =
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
-
- Value subResult = b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
+
+ Value subResult =
+ b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
// Step 5: Round up if necessary.
- Value cRound = createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
+ Value cRound =
+ createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
Value shouldRound =
b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
- shouldRound =
- b.create<arith::OrIOp>(shouldRound, isSubnormal);
+ shouldRound = b.create<arith::OrIOp>(shouldRound, isSubnormal);
Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1);
f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
``````````
</details>
https://github.com/llvm/llvm-project/pull/144157
More information about the Mlir-commits mailing list