[Mlir-commits] [mlir] [mlir][arith][transforms] Adds Truncf f32 to f4e2m1 (PR #144157)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 16 10:12:38 PDT 2025


https://github.com/Muzammiluddin-Syed-ECE updated https://github.com/llvm/llvm-project/pull/144157

>From 6ef31a776b020ea5fb34ff6882aba7fe67ff421b Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Mon, 9 Jun 2025 21:06:41 +0000
Subject: [PATCH 1/4] initial implementation to fix

---
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 206 ++++++++++++++++++
 1 file changed, 206 insertions(+)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 534aff9562b7a..87483ced8e5cf 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -322,6 +322,57 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   }
 };
 
+struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!llvm::isa<Float4E2M1FNType>(operandETy)) {
+      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
+    }
+
+    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+    Value bitcast = b.create<arith::BitcastOp>(i4Ty, operand);
+    
+    // create constants to extract mantissa / exponent
+    Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
+    Value cF4SignAndExpWidth = createConst(op->getLoc(), i32Ty, 3, rewriter);
+    // Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
+
+    // create constants for NaNs
+    Value cF4NaN = createConst(op.getLoc(), i4Ty, 0xf, rewriter);
+    Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
+    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+
+    Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+    Value f32Exponent = b.create<arith::ShLIOp>(exti, cF4MantissaWidth);
+    Value f32Mantissa = b.create<arith::ShRUIOp>(exti, cF4SignAndExpWidth);
+    Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
+
+    Value isNan =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
+    // select for NaNs
+    f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
+    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+    if (resultETy.getIntOrFloatBitWidth() < 32) {
+      result = b.create<arith::TruncFOp>(resultTy, result);
+    } else if (resultETy.getIntOrFloatBitWidth() > 32) {
+      result = b.create<arith::ExtFOp>(resultTy, result);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(arith::ExtFOp op,
@@ -365,6 +416,161 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
     return success();
   }
 };
+/*
+Conversion from F32 to F4E2M1 according to the OCP Spec:
+www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+
+The spec requiers us to perform Round to Nearest, Ties to Even.
+
+This means that after rounding, we should break ties by choosing the option
+which results in a mantissa of 0 in the least significant digit.
+
+Table of representable values in F4E2M1:
+
+Note: x is sign bit
+| Binary | Value ( + / - ) 
+| x000   | 0.0
+| x001   | 0.5
+| x010   | 1.0
+| x011   | 1.5
+| x100   | 2.0
+| x101   | 3.0
+| x110   | 4.0
+| x111   | 6.0
+
+Conversion procedure: 
+Let M_0 = f32 mantissa, M_1 = f4 mantissa, Let E_0 = f32 exp, let E_1 = f4 exp
+Create bias adjusted exponent, E_1 <- E_0 - 126
+If E_0 <= 0111 1110
+    M_1 <- 0, E_1 <- 00
+    end
+if E_1 == 00 (special case for almost subnormal)
+    if we must round up (M_0 >= 10000000000000000000000)
+        M_1 <- 0 
+        E_1 <- 01
+    else
+        M_1 <- 1
+    end
+Else if E_1 > 00
+    roundToEven <- M_0 <= 01000000000000000000000 || M_0 >= 11000000000000000000000
+    if roundToEven
+        M_1 <- 0 
+    else
+        M_1 <- 1
+    If M_0 >= 11000000000000000000000
+        increment E_1
+    If E_1 > 11 (saturate if beyond range)
+        M_1 <- 1, E_1 <- 11
+end
+*/
+struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::TruncFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!llvm::isa<Float4E2M1FNType>(resultETy)) {
+      return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
+    }
+
+    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+    Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+    // Constants
+    Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
+    Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
+    Value c0x7e = createConst(op.getLoc(), i8Ty, 0x7e, rewriter);
+    Value c0x00000009 = createConst(op->getLoc(), i32Ty, 9, rewriter);
+    Value c0x00000017 = createConst(op->getLoc(), i32Ty, 23, rewriter);
+    Value c0x00200000 = createConst(op.getLoc(), i32Ty, 0x200000, rewriter);
+    Value c0x00400000 = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
+    Value c0x00600000 = createConst(op.getLoc(), i32Ty, 0x600000, rewriter);
+    Value c0x007fffff = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
+    
+    Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
+  
+    Value cF32MantissaWidth = c0x00000017; // 23
+    Value cF4MantissaWidth = c0x1; // 1
+    Value cF32SignExpWidth = c0x00000009; // 9
+    Value cF32MantissaMask = c0x007fffff;
+    Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
+    Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
+    Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
+    
+    Value cSubnormalExp = c0x7e; // 126
+
+    // Regular case
+    Value biasAdjustment = c0x7e; // 126
+    Value cRoundUp = c0x00600000; // 110 0000...
+    Value cRoundDown = c0x00200000; // 010 0000...
+    Value biasAdjustedExp = b.create<arith::SubIOp>(exp8Bits, biasAdjustment);
+    Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedExp);
+    Value f4ExpRounded = b.create<arith::AddIOp>(f4Exp, c0x1);
+    // If we round up or down to even, set mantissa to 0
+    Value shouldRoundUp =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man23Bits, cRoundUp);
+    Value shouldRoundDown = b.create<arith::CmpIOp>(arith::CmpIPredicate::ule,
+                                                    man23Bits, cRoundDown);
+    // dont need to worry about saturation this way
+    f4Exp = b.create<arith::SelectOp>(shouldRoundUp, f4ExpRounded, f4Exp);
+    Value f4BitsMan0 = b.create<arith::ShLIOp>(f4Exp, cF4MantissaWidth);
+    Value f4Bits = b.create<arith::AddIOp>(f4BitsMan0, c0x1);
+    f4Bits = b.create<arith::SelectOp>(shouldRoundUp, f4BitsMan0, f4Bits);
+    f4Bits = b.create<arith::SelectOp>(shouldRoundDown, f4BitsMan0, f4Bits);
+
+    // Bordering subnormal
+    Value cSubnormalRoundUp =
+        createConst(op.getLoc(), i32Ty, 0x4fffff, rewriter);
+    Value f4Edge = createConst(op.getLoc(), i4Ty, 0x1, rewriter);
+    Value f4EdgeRounded = createConst(op.getLoc(), i4Ty, 0x2, rewriter);
+    Value isEdgeRounded = b.create<arith::CmpIOp>(arith::CmpIPredicate::uge,
+                                                  man23Bits, cSubnormalRoundUp);
+    f4EdgeRounded =
+        b.create<arith::SelectOp>(isEdgeRounded, f4EdgeRounded, f4Edge);
+    Value isEdge = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, exp8Bits,
+                                           cSubnormalExp);
+
+    // Subnormal
+    Value f4Zero = createConst(op.getLoc(), i4Ty, 0x0, rewriter);
+    Value isZero = b.create<arith::CmpIOp>(arith::CmpIPredicate::ule, exp8Bits,
+                                          cSubnormalExp);
+
+    // create constants to extract mantissa / exponent
+    Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
+    Value cF4SignAndExpWidth = createConst(op->getLoc(), i32Ty, 3, rewriter);
+    // Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
+
+    // create constants for NaNs
+    Value cF4NaN = createConst(op.getLoc(), i4Ty, 0xf, rewriter);
+    Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
+    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+
+    Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+    Value f32Exponent = b.create<arith::ShLIOp>(exti, cF4MantissaWidth);
+    Value f32Mantissa = b.create<arith::ShRUIOp>(exti, cF4SignAndExpWidth);
+    Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
+
+    Value isNan =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
+    // select for NaNs
+    f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
+    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+    if (resultETy.getIntOrFloatBitWidth() < 32) {
+      result = b.create<arith::TruncFOp>(resultTy, result);
+    } else if (resultETy.getIntOrFloatBitWidth() > 32) {
+      result = b.create<arith::ExtFOp>(resultTy, result);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
 
 /*
 TruncF to F8E8M0 is expected to extract exponent bits out of F32 type

>From abc108e9f40adfcb1ac5b30680626ea5f1a35d33 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Thu, 12 Jun 2025 16:26:50 +0000
Subject: [PATCH 2/4] intermediate commit

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 79 ++++++++++++++++---
 1 file changed, 66 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 87483ced8e5cf..0fb802b82fffb 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -34,6 +34,18 @@ static Value createConst(Location loc, Type type, int value,
   return rewriter.create<arith::ConstantOp>(loc, attr);
 }
 
+/// Create an float constant.
+static Value createFloatConst(Location loc, Type type, float value,
+                         PatternRewriter &rewriter) {
+  auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
+  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
+    return rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(shapedTy, attr));
+  }
+
+  return rewriter.create<arith::ConstantOp>(loc, attr);
+}
+
 /// Creates shapedType using shape from cloneFrom and base type from cloneTo
 static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
   if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
@@ -439,6 +451,13 @@ Note: x is sign bit
 | x111   | 6.0
 
 Conversion procedure: 
+
+Step 1: Clamp to max f4 value
+
+Step 2: convert exponent, if signed int comparison <= 0, set 0
+
+Step 3: if mantissa[1:] greater than 1000000, add 1
+
 Let M_0 = f32 mantissa, M_1 = f4 mantissa, Let E_0 = f32 exp, let E_1 = f4 exp
 Create bias adjusted exponent, E_1 <- E_0 - 126
 If E_0 <= 0111 1110
@@ -485,32 +504,66 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
 
     // Constants
     Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
-    Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
     Value c0x7e = createConst(op.getLoc(), i8Ty, 0x7e, rewriter);
+    Value c0x0000007e = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
+    
     Value c0x00000009 = createConst(op->getLoc(), i32Ty, 9, rewriter);
+    Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
     Value c0x00000017 = createConst(op->getLoc(), i32Ty, 23, rewriter);
+    Value c0x0000001f = createConst(op->getLoc(), i32Ty, 31, rewriter);
     Value c0x00200000 = createConst(op.getLoc(), i32Ty, 0x200000, rewriter);
     Value c0x00400000 = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
     Value c0x00600000 = createConst(op.getLoc(), i32Ty, 0x600000, rewriter);
+    Value c0x003fffff = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
     Value c0x007fffff = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
-    
-    Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
-  
     Value cF32MantissaWidth = c0x00000017; // 23
-    Value cF4MantissaWidth = c0x1; // 1
-    Value cF32SignExpWidth = c0x00000009; // 9
+    Value cF4MantissaWidth = c0x1;         // 1
+    Value cF32SignExpWidth = c0x00000009;  // 9
+    Value cF32FirstBitMask = c0x00400000;
+    Value cF32Last22BitMask = c0x003fffff;
     Value cF32MantissaMask = c0x007fffff;
+
+    // Step 1: Clamp to bounds.
+    Value cHigherBound = createFloatConst(op->getLoc(), f32Ty, 6.0, rewriter);
+    Value cLowerBound = createFloatConst(op->getLoc(), f32Ty, -6.0, rewriter);
+    Value clampHigh = b.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, operand, cHigherBound);
+    Value clampLow = b.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, operand, cLowerBound);
+    Value operandClamped = b.create<arith::SelectOp>(clampHigh, cHigherBound, operand);
+    operandClamped = b.create<arith::SelectOp>(clampLow, cLowerBound, operandClamped);
+    Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
+  
+    // Step 2: Convert exponent by adjusting bias.
     Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
-    Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
-    Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
+    Value biasAdjustment = c0x0000007e; // 126
+    Value biasAdjustedSignExp = b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
+    Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
+    f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
     
-    Value cSubnormalExp = c0x7e; // 126
+    // Step 0: Special consideration for conversion to 0.5.
+    Value cSubnormalLowerBound = createFloatConst(op->getLoc(), f32Ty, 0.25, rewriter);
+    Value cSubnormalHigherBound = createFloatConst(op->getLoc(), f32Ty, 0.75, rewriter);
+    Value cLowerBound = createConst(op->getLoc(), f32Ty, -6.0, rewriter);
+    Value isSubnormal =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
+
+    // Step 3: Set mantissa to first bit.
+    Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
+    Value man1Bit = b.create<arith::ShRUIOp>(man23Bits, c0x00000016);
+    Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
+    Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
+
+    // Step 4: Round up if necessary.
+    Value cRound = c0x00200000; // 010 0000...
+    Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
+    Value shouldRound =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
+    Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1);
+    f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
 
     // Regular case
-    Value biasAdjustment = c0x7e; // 126
-    Value cRoundUp = c0x00600000; // 110 0000...
-    Value cRoundDown = c0x00200000; // 010 0000...
-    Value biasAdjustedExp = b.create<arith::SubIOp>(exp8Bits, biasAdjustment);
+
+    
+
     Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedExp);
     Value f4ExpRounded = b.create<arith::AddIOp>(f4Exp, c0x1);
     // If we round up or down to even, set mantissa to 0

>From 63f2337c538f463f55104b4470aa88a07d364433 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Fri, 13 Jun 2025 13:06:00 -0700
Subject: [PATCH 3/4] Initial implementation of truncf

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 265 +++++-------------
 1 file changed, 66 insertions(+), 199 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 0fb802b82fffb..e35154e8ae32b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -334,57 +334,6 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   }
 };
 
-struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(arith::ExtFOp op,
-                                PatternRewriter &rewriter) const final {
-    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    Value operand = op.getOperand();
-    Type operandTy = operand.getType();
-    Type resultTy = op.getType();
-    Type operandETy = getElementTypeOrSelf(operandTy);
-    Type resultETy = getElementTypeOrSelf(resultTy);
-
-    if (!llvm::isa<Float4E2M1FNType>(operandETy)) {
-      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
-    }
-
-    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
-    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
-    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
-
-    Value bitcast = b.create<arith::BitcastOp>(i4Ty, operand);
-    
-    // create constants to extract mantissa / exponent
-    Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
-    Value cF4SignAndExpWidth = createConst(op->getLoc(), i32Ty, 3, rewriter);
-    // Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
-
-    // create constants for NaNs
-    Value cF4NaN = createConst(op.getLoc(), i4Ty, 0xf, rewriter);
-    Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
-    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
-
-    Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
-    Value f32Exponent = b.create<arith::ShLIOp>(exti, cF4MantissaWidth);
-    Value f32Mantissa = b.create<arith::ShRUIOp>(exti, cF4SignAndExpWidth);
-    Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
-
-    Value isNan =
-        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
-    // select for NaNs
-    f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
-    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
-    if (resultETy.getIntOrFloatBitWidth() < 32) {
-      result = b.create<arith::TruncFOp>(resultTy, result);
-    } else if (resultETy.getIntOrFloatBitWidth() > 32) {
-      result = b.create<arith::ExtFOp>(resultTy, result);
-    }
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-
 struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(arith::ExtFOp op,
@@ -428,60 +377,34 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
     return success();
   }
 };
-/*
-Conversion from F32 to F4E2M1 according to the OCP Spec:
-www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
-
-The spec requiers us to perform Round to Nearest, Ties to Even.
-
-This means that after rounding, we should break ties by choosing the option
-which results in a mantissa of 0 in the least significant digit.
-
-Table of representable values in F4E2M1:
-
-Note: x is sign bit
-| Binary | Value ( + / - ) 
-| x000   | 0.0
-| x001   | 0.5
-| x010   | 1.0
-| x011   | 1.5
-| x100   | 2.0
-| x101   | 3.0
-| x110   | 4.0
-| x111   | 6.0
-
-Conversion procedure: 
-
-Step 1: Clamp to max f4 value
-
-Step 2: convert exponent, if signed int comparison <= 0, set 0
-
-Step 3: if mantissa[1:] greater than 1000000, add 1
-
-Let M_0 = f32 mantissa, M_1 = f4 mantissa, Let E_0 = f32 exp, let E_1 = f4 exp
-Create bias adjusted exponent, E_1 <- E_0 - 126
-If E_0 <= 0111 1110
-    M_1 <- 0, E_1 <- 00
-    end
-if E_1 == 00 (special case for almost subnormal)
-    if we must round up (M_0 >= 10000000000000000000000)
-        M_1 <- 0 
-        E_1 <- 01
-    else
-        M_1 <- 1
-    end
-Else if E_1 > 00
-    roundToEven <- M_0 <= 01000000000000000000000 || M_0 >= 11000000000000000000000
-    if roundToEven
-        M_1 <- 0 
-    else
-        M_1 <- 1
-    If M_0 >= 11000000000000000000000
-        increment E_1
-    If E_1 > 11 (saturate if beyond range)
-        M_1 <- 1, E_1 <- 11
-end
-*/
+
+/// Conversion from F32 to F4E2M1 according to the OCP Spec:
+/// www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+///
+/// The spec requires us to perform Round to Nearest, Ties to Even.
+///
+/// This means that after rounding, we should break ties by choosing the option
+/// which results in a mantissa of 0 in the least significant digit.
+///
+/// Table of representable values in F4E2M1:
+///
+/// Note: x is sign bit
+/// | Binary | Value ( + / - ) 
+/// | x000   | 0.0
+/// | x001   | 0.5
+/// | x010   | 1.0
+/// | x011   | 1.5
+/// | x100   | 2.0
+/// | x101   | 3.0
+/// | x110   | 4.0
+/// | x111   | 6.0
+///
+/// Conversion procedure: 
+///   Step 1: Clamp to representable bounds.
+///   Step 2: Convert exponent by adjusting bias.
+///   Step 3: Set mantissa to first bit.
+///   Step 4: Special consideration for subnormal and zero exponent.
+///   Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or subnormal.
 struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(arith::TruncFOp op,
@@ -504,122 +427,66 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
 
     // Constants
     Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
-    Value c0x7e = createConst(op.getLoc(), i8Ty, 0x7e, rewriter);
-    Value c0x0000007e = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
-    
-    Value c0x00000009 = createConst(op->getLoc(), i32Ty, 9, rewriter);
-    Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
-    Value c0x00000017 = createConst(op->getLoc(), i32Ty, 23, rewriter);
-    Value c0x0000001f = createConst(op->getLoc(), i32Ty, 31, rewriter);
-    Value c0x00200000 = createConst(op.getLoc(), i32Ty, 0x200000, rewriter);
-    Value c0x00400000 = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
-    Value c0x00600000 = createConst(op.getLoc(), i32Ty, 0x600000, rewriter);
-    Value c0x003fffff = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
-    Value c0x007fffff = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
-    Value cF32MantissaWidth = c0x00000017; // 23
+    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
     Value cF4MantissaWidth = c0x1;         // 1
-    Value cF32SignExpWidth = c0x00000009;  // 9
-    Value cF32FirstBitMask = c0x00400000;
-    Value cF32Last22BitMask = c0x003fffff;
-    Value cF32MantissaMask = c0x007fffff;
-
+    Value cF32FirstBitMask = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
+    Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
+    Value c0x00 = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
+    Value c0xff = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
+    Value cF32MantissaMask = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
+    Value c0x00000000 = createConst(op.getLoc(), i32Ty, 0, rewriter);
+    Value cF32Last22BitMask = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);;
+    
     // Step 1: Clamp to bounds.
     Value cHigherBound = createFloatConst(op->getLoc(), f32Ty, 6.0, rewriter);
     Value cLowerBound = createFloatConst(op->getLoc(), f32Ty, -6.0, rewriter);
-    Value clampHigh = b.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, operand, cHigherBound);
-    Value clampLow = b.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, operand, cLowerBound);
-    Value operandClamped = b.create<arith::SelectOp>(clampHigh, cHigherBound, operand);
-    operandClamped = b.create<arith::SelectOp>(clampLow, cLowerBound, operandClamped);
+    Value operandClamped = b.create<arith::MinimumFOp>(clampLow, operand);
+    operandClamped = b.create<arith::MaximumFOp>(clampHigh, operandClamped);
     Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
   
     // Step 2: Convert exponent by adjusting bias.
     Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
-    Value biasAdjustment = c0x0000007e; // 126
+    Value biasAdjustment = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
     Value biasAdjustedSignExp = b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
     Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
     f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
     
-    // Step 0: Special consideration for conversion to 0.5.
-    Value cSubnormalLowerBound = createFloatConst(op->getLoc(), f32Ty, 0.25, rewriter);
-    Value cSubnormalHigherBound = createFloatConst(op->getLoc(), f32Ty, 0.75, rewriter);
-    Value cLowerBound = createConst(op->getLoc(), f32Ty, -6.0, rewriter);
-    Value isSubnormal =
-        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
-
     // Step 3: Set mantissa to first bit.
-    Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
-    Value man1Bit = b.create<arith::ShRUIOp>(man23Bits, c0x00000016);
+    Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
+    man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
     Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
     Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
+    
+    // Step 4: Special consideration for conversion to 0.5.
+    Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
+    Value isSubnormal =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
+    Value isNegOneExp =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
+    Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
+    Value isNonZeroMan =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
+    Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
+    Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
+    Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
+    Value isZeroExp = 
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
+    
+    Value subResult = b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
+    subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
+    f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
 
-    // Step 4: Round up if necessary.
-    Value cRound = c0x00200000; // 010 0000...
+    // Step 5: Round up if necessary.
+    Value cRound = createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
     Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
     Value shouldRound =
         b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
+    shouldRound =
+        b.create<arith::OrIOp>(shouldRound, isSubnormal);
     Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1);
     f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
 
-    // Regular case
-
-    
-
-    Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedExp);
-    Value f4ExpRounded = b.create<arith::AddIOp>(f4Exp, c0x1);
-    // If we round up or down to even, set mantissa to 0
-    Value shouldRoundUp =
-        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man23Bits, cRoundUp);
-    Value shouldRoundDown = b.create<arith::CmpIOp>(arith::CmpIPredicate::ule,
-                                                    man23Bits, cRoundDown);
-    // dont need to worry about saturation this way
-    f4Exp = b.create<arith::SelectOp>(shouldRoundUp, f4ExpRounded, f4Exp);
-    Value f4BitsMan0 = b.create<arith::ShLIOp>(f4Exp, cF4MantissaWidth);
-    Value f4Bits = b.create<arith::AddIOp>(f4BitsMan0, c0x1);
-    f4Bits = b.create<arith::SelectOp>(shouldRoundUp, f4BitsMan0, f4Bits);
-    f4Bits = b.create<arith::SelectOp>(shouldRoundDown, f4BitsMan0, f4Bits);
-
-    // Bordering subnormal
-    Value cSubnormalRoundUp =
-        createConst(op.getLoc(), i32Ty, 0x4fffff, rewriter);
-    Value f4Edge = createConst(op.getLoc(), i4Ty, 0x1, rewriter);
-    Value f4EdgeRounded = createConst(op.getLoc(), i4Ty, 0x2, rewriter);
-    Value isEdgeRounded = b.create<arith::CmpIOp>(arith::CmpIPredicate::uge,
-                                                  man23Bits, cSubnormalRoundUp);
-    f4EdgeRounded =
-        b.create<arith::SelectOp>(isEdgeRounded, f4EdgeRounded, f4Edge);
-    Value isEdge = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, exp8Bits,
-                                           cSubnormalExp);
-
-    // Subnormal
-    Value f4Zero = createConst(op.getLoc(), i4Ty, 0x0, rewriter);
-    Value isZero = b.create<arith::CmpIOp>(arith::CmpIPredicate::ule, exp8Bits,
-                                          cSubnormalExp);
-
-    // create constants to extract mantissa / exponent
-    Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
-    Value cF4SignAndExpWidth = createConst(op->getLoc(), i32Ty, 3, rewriter);
-    // Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
-
-    // create constants for NaNs
-    Value cF4NaN = createConst(op.getLoc(), i4Ty, 0xf, rewriter);
-    Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
-    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
-
-    Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
-    Value f32Exponent = b.create<arith::ShLIOp>(exti, cF4MantissaWidth);
-    Value f32Mantissa = b.create<arith::ShRUIOp>(exti, cF4SignAndExpWidth);
-    Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
-
-    Value isNan =
-        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
-    // select for NaNs
-    f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
-    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
-    if (resultETy.getIntOrFloatBitWidth() < 32) {
-      result = b.create<arith::TruncFOp>(resultTy, result);
-    } else if (resultETy.getIntOrFloatBitWidth() > 32) {
-      result = b.create<arith::ExtFOp>(resultTy, result);
-    }
+    Value result = b.create<arith::BitcastOp>(resultTy, f4Bits);
     rewriter.replaceOp(op, result);
     return success();
   }

>From 78d9e7a6031c4e947e288171822745feb9c507d1 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Mon, 16 Jun 2025 10:06:34 -0700
Subject: [PATCH 4/4] PR Review round 1

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 64 ++++++++++---------
 1 file changed, 35 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index e35154e8ae32b..5f68710553cdc 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -34,9 +34,9 @@ static Value createConst(Location loc, Type type, int value,
   return rewriter.create<arith::ConstantOp>(loc, attr);
 }
 
-/// Create an float constant.
-static Value createFloatConst(Location loc, Type type, float value,
-                         PatternRewriter &rewriter) {
+/// Create a float constant.
+static Value createFloatConst(Location loc, Type type, APFloat value,
+                              PatternRewriter &rewriter) {
   auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
   if (auto shapedTy = dyn_cast<ShapedType>(type)) {
     return rewriter.create<arith::ConstantOp>(
@@ -389,7 +389,7 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
 /// Table of representable values in F4E2M1:
 ///
 /// Note: x is sign bit
-/// | Binary | Value ( + / - ) 
+/// | Binary | Value ( + / - )
 /// | x000   | 0.0
 /// | x001   | 0.5
 /// | x010   | 1.0
@@ -399,12 +399,13 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
 /// | x110   | 4.0
 /// | x111   | 6.0
 ///
-/// Conversion procedure: 
+/// Conversion procedure:
 ///   Step 1: Clamp to representable bounds.
 ///   Step 2: Convert exponent by adjusting bias.
 ///   Step 3: Set mantissa to first bit.
 ///   Step 4: Special consideration for subnormal and zero exponent.
-///   Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or subnormal.
+///   Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or
+///   subnormal.
 struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(arith::TruncFOp op,
@@ -416,8 +417,8 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     Type operandETy = getElementTypeOrSelf(operandTy);
     Type resultETy = getElementTypeOrSelf(resultTy);
 
-    if (!llvm::isa<Float4E2M1FNType>(resultETy)) {
-      return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
+    if (!isa<Float32Type>(operandETy) || !isa<Float4E2M1FNType>(resultETy)) {
+      return rewriter.notifyMatchFailure(op, "not a trunc of F32 to F4E2M1FN");
     }
 
     Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
@@ -425,64 +426,69 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
     Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
 
-    // Constants
     Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
-    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
-    Value cF4MantissaWidth = c0x1;         // 1
-    Value cF32FirstBitMask = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
     Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
     Value c0x00 = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
     Value c0xff = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
-    Value cF32MantissaMask = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
     Value c0x00000000 = createConst(op.getLoc(), i32Ty, 0, rewriter);
-    Value cF32Last22BitMask = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);;
-    
+
     // Step 1: Clamp to bounds.
     Value cHigherBound = createFloatConst(op->getLoc(), f32Ty, 6.0, rewriter);
     Value cLowerBound = createFloatConst(op->getLoc(), f32Ty, -6.0, rewriter);
     Value operandClamped = b.create<arith::MinimumFOp>(clampLow, operand);
     operandClamped = b.create<arith::MaximumFOp>(clampHigh, operandClamped);
     Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
-  
+
     // Step 2: Convert exponent by adjusting bias.
-    Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
     Value biasAdjustment = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
-    Value biasAdjustedSignExp = b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
+    Value cF4MantissaWidth = c0x1; // 1
+    Value cF32MantissaWidth =
+        createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
+    Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
+    Value biasAdjustedSignExp =
+        b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
     Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
     f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
-    
+
     // Step 3: Set mantissa to first bit.
+    Value cF32FirstBitMask =
+        createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
     Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
     man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
     Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
     Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
-    
+
     // Step 4: Special consideration for conversion to 0.5.
+    Value cF32MantissaMask =
+        createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
     Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
     Value isSubnormal =
         b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
     Value isNegOneExp =
         b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
     Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
-    Value isNonZeroMan =
-        b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
+    Value isNonZeroMan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt,
+                                                 man23Bits, c0x00000000);
     Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
+    Value isZeroExp =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
+
     Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
     Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
-    Value isZeroExp = 
-        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
-    
-    Value subResult = b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
+    Value subResult =
+        b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
     subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
     f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
 
     // Step 5: Round up if necessary.
-    Value cRound = createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
+    Value cF32Last22BitMask =
+        createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
+    Value cRound =
+        createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
     Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
     Value shouldRound =
         b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
-    shouldRound =
-        b.create<arith::OrIOp>(shouldRound, isSubnormal);
+    shouldRound = b.create<arith::OrIOp>(shouldRound, isSubnormal);
     Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1);
     f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
 



More information about the Mlir-commits mailing list