[Mlir-commits] [mlir] [Arith][Transforms] Adds Truncf f32 to f4e2m1 (PR #144157)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 13 13:33:35 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Muzammil (Muzammiluddin-Syed-ECE)

<details>
<summary>Changes</summary>

See work detail: https://github.com/iree-org/iree/issues/20920

Add support for FP32 -> MXFP F4 in `arith.truncf` op

---
Full diff: https://github.com/llvm/llvm-project/pull/144157.diff


1 Files Affected:

- (modified) mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp (+128) 


``````````diff
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 534aff9562b7a..40d080d83ad38 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -34,6 +34,18 @@ static Value createConst(Location loc, Type type, int value,
   return rewriter.create<arith::ConstantOp>(loc, attr);
 }
 
+/// Create a float constant.
+static Value createFloatConst(Location loc, Type type, float value,
+                         PatternRewriter &rewriter) {
+  auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
+  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
+    return rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(shapedTy, attr));
+  }
+
+  return rewriter.create<arith::ConstantOp>(loc, attr);
+}
+
 /// Creates shapedType using shape from cloneFrom and base type from cloneTo
 static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
   if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
@@ -366,6 +378,122 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
   }
 };
 
+/// Conversion from F32 to F4E2M1 according to the OCP Spec:
+/// www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+///
+/// The spec requires us to perform Round to Nearest, Ties to Even.
+///
+/// This means that when a value is exactly halfway between two representable
+/// values, we pick the one whose least significant mantissa bit is 0.
+///
+/// Table of representable values in F4E2M1:
+///
+/// Note: x is sign bit
+/// | Binary | Value ( + / - ) 
+/// | x000   | 0.0
+/// | x001   | 0.5
+/// | x010   | 1.0
+/// | x011   | 1.5
+/// | x100   | 2.0
+/// | x101   | 3.0
+/// | x110   | 4.0
+/// | x111   | 6.0
+///
+/// Conversion procedure: 
+///   Step 1: Clamp to representable bounds.
+///   Step 2: Convert exponent by adjusting bias.
+///   Step 3: Set mantissa to first bit.
+///   Step 4: Special consideration for subnormal and zero exponent.
+///   Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or subnormal.
+struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::TruncFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!llvm::isa<Float4E2M1FNType>(resultETy)) {
+      return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
+    }
+
+    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+    Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+    // Constants
+    Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
+    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
+    Value cF4MantissaWidth = c0x1;         // 1
+    Value cF32FirstBitMask = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
+    Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
+    Value c0x00 = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
+    Value c0xff = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
+    Value cF32MantissaMask = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
+    Value c0x00000000 = createConst(op.getLoc(), i32Ty, 0, rewriter);
+    Value cF32Last22BitMask = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);;
+    
+    // Step 1: Clamp to bounds.
+    Value cHigherBound = createFloatConst(op->getLoc(), f32Ty, 6.0, rewriter);
+    Value cLowerBound = createFloatConst(op->getLoc(), f32Ty, -6.0, rewriter);
+    Value clampHigh = b.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, operand, cHigherBound);
+    Value clampLow = b.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, operand, cLowerBound);
+    Value operandClamped = b.create<arith::SelectOp>(clampHigh, cHigherBound, operand);
+    operandClamped = b.create<arith::SelectOp>(clampLow, cLowerBound, operandClamped);
+    Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
+  
+    // Step 2: Convert exponent by adjusting bias.
+    Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
+    Value biasAdjustment = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
+    Value biasAdjustedSignExp = b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
+    Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
+    f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
+    
+    // Step 3: Set mantissa to first bit.
+    Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
+    man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
+    Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
+    Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
+    
+    // Step 4: Special consideration for conversion to 0.5.
+    Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
+    Value isSubnormal =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
+    Value isNegOneExp =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
+    Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
+    Value isNonZeroMan =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
+    Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
+    Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
+    Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
+    Value isZeroExp = 
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
+    
+    Value subResult = b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
+    subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
+    f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
+
+    // Step 5: Round up if necessary.
+    Value cRound = createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
+    Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
+    Value shouldRound =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
+    shouldRound =
+        b.create<arith::OrIOp>(shouldRound, isSubnormal);
+    Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1);
+    f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
+
+    Value result = b.create<arith::BitcastOp>(resultTy, f4Bits);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 /*
 TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
 Since All kinds of Infs and NaNs are mapped to same exponent bits in F32 type,

``````````

</details>


https://github.com/llvm/llvm-project/pull/144157


More information about the Mlir-commits mailing list