[Mlir-commits] [mlir] [mlir][arith][transforms] Adds Truncf f32 to f4e2m1 (PR #144157)
Krzysztof Drewniak
llvmlistbot at llvm.org
Tue Jun 17 12:20:57 PDT 2025
================
@@ -366,6 +446,139 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
}
};
+/// Conversion from F32 to F4E2M1 according to the OCP Spec:
+/// https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+///
+/// The spec requires us to perform Round to Nearest, Ties to Even.
+///
+/// This means that when a value lies exactly halfway between two representable
+/// values, we break the tie by choosing the one whose least significant
+/// mantissa bit is 0.
+///
+/// Table of representable values in F4E2M1:
+///
+/// Note: x is sign bit
+/// | Binary | Value ( + / - )
+/// | x000 | 0.0
+/// | x001 | 0.5
+/// | x010 | 1.0
+/// | x011 | 1.5
+/// | x100 | 2.0
+/// | x101 | 3.0
+/// | x110 | 4.0
+/// | x111 | 6.0
+///
+/// Conversion procedure:
+/// Step 1: Clamp to representable bounds.
+/// Step 2: Convert exponent by adjusting bias.
+/// Step 3: Set mantissa to first bit.
+/// Step 4: Special consideration for subnormal and zero exponent.
+/// Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or
+/// subnormal.
+struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::TruncFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ Value operand = op.getOperand();
+ Type operandTy = operand.getType();
+ Type resultTy = op.getType();
+ Type operandETy = getElementTypeOrSelf(operandTy);
+ Type resultETy = getElementTypeOrSelf(resultTy);
+
+ if (!isa<Float32Type>(operandETy)) {
+ operand = b.create<arith::ExtFOp>(b.getF32Type(), operand);
+ }
+ if (!isa<Float4E2M1FNType>(resultETy)) {
+ return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
+ }
+
+ Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+ Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
+ Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+ Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+ Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
+ Value c0x3 = createConst(op->getLoc(), i4Ty, 3, rewriter);
+ Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
+ Value c0x00 = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
+ Value c0xff = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
+ Value c0x00000000 = createConst(op.getLoc(), i32Ty, 0, rewriter);
+
+ // Step 0: Clamp to bounds.
+ Value cHigherBound =
+ createFloatConst(op->getLoc(), f32Ty, APFloat(6.0f), rewriter);
+ Value cLowerBound =
+ createFloatConst(op->getLoc(), f32Ty, APFloat(-6.0f), rewriter);
+ Value operandClamped = b.create<arith::MinimumFOp>(cHigherBound, operand);
----------------
krzysz00 wrote:
I think we want the ones that lower to `llvm.minnum` and `llvm.maxnum` here (I forget which they are), so that NaN propagates correctly.
https://github.com/llvm/llvm-project/pull/144157
More information about the Mlir-commits mailing list