[Mlir-commits] [mlir] [mlir][arith][transforms] Adds Truncf f32 to f4e2m1 (PR #144157)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jun 16 10:08:33 PDT 2025
https://github.com/Muzammiluddin-Syed-ECE updated https://github.com/llvm/llvm-project/pull/144157
>From 6ef31a776b020ea5fb34ff6882aba7fe67ff421b Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Mon, 9 Jun 2025 21:06:41 +0000
Subject: [PATCH 1/4] initial implementation to fix
---
.../Dialect/Arith/Transforms/ExpandOps.cpp | 206 ++++++++++++++++++
1 file changed, 206 insertions(+)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 534aff9562b7a..87483ced8e5cf 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -322,6 +322,57 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
}
};
+struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::ExtFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ Value operand = op.getOperand();
+ Type operandTy = operand.getType();
+ Type resultTy = op.getType();
+ Type operandETy = getElementTypeOrSelf(operandTy);
+ Type resultETy = getElementTypeOrSelf(resultTy);
+
+ if (!llvm::isa<Float4E2M1FNType>(operandETy)) {
+ return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
+ }
+
+ Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+ Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+ Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+ Value bitcast = b.create<arith::BitcastOp>(i4Ty, operand);
+
+ // create constants to extract mantissa / exponent
+ Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
+ Value cF4SignAndExpWidth = createConst(op->getLoc(), i32Ty, 3, rewriter);
+ // Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
+
+ // create constants for NaNs
+ Value cF4NaN = createConst(op.getLoc(), i4Ty, 0xf, rewriter);
+ Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
+ Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+
+ Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+ Value f32Exponent = b.create<arith::ShLIOp>(exti, cF4MantissaWidth);
+ Value f32Mantissa = b.create<arith::ShRUIOp>(exti, cF4SignAndExpWidth);
+ Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
+
+ Value isNan =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
+ // select for NaNs
+ f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
+ Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+ if (resultETy.getIntOrFloatBitWidth() < 32) {
+ result = b.create<arith::TruncFOp>(resultTy, result);
+ } else if (resultETy.getIntOrFloatBitWidth() > 32) {
+ result = b.create<arith::ExtFOp>(resultTy, result);
+ }
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::ExtFOp op,
@@ -365,6 +416,161 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
return success();
}
};
+/*
+Conversion from F32 to F4E2M1 according to the OCP Spec:
+www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+
+The spec requiers us to perform Round to Nearest, Ties to Even.
+
+This means that after rounding, we should break ties by choosing the option
+which results in a mantissa of 0 in the least significant digit.
+
+Table of representable values in F4E2M1:
+
+Note: x is sign bit
+| Binary | Value ( + / - )
+| x000 | 0.0
+| x001 | 0.5
+| x010 | 1.0
+| x011 | 1.5
+| x100 | 2.0
+| x101 | 3.0
+| x110 | 4.0
+| x111 | 6.0
+
+Conversion procedure:
+Let M_0 = f32 mantissa, M_1 = f4 mantissa, Let E_0 = f32 exp, let E_1 = f4 exp
+Create bias adjusted exponent, E_1 <- E_0 - 126
+If E_0 <= 0111 1110
+ M_1 <- 0, E_1 <- 00
+ end
+if E_1 == 00 (special case for almost subnormal)
+ if we must round up (M_0 >= 10000000000000000000000)
+ M_1 <- 0
+ E_1 <- 01
+ else
+ M_1 <- 1
+ end
+Else if E_1 > 00
+ roundToEven <- M_0 <= 01000000000000000000000 || M_0 >= 11000000000000000000000
+ if roundToEven
+ M_1 <- 0
+ else
+ M_1 <- 1
+ If M_0 >= 11000000000000000000000
+ increment E_1
+ If E_1 > 11 (saturate if beyond range)
+ M_1 <- 1, E_1 <- 11
+end
+*/
+struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::TruncFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ Value operand = op.getOperand();
+ Type operandTy = operand.getType();
+ Type resultTy = op.getType();
+ Type operandETy = getElementTypeOrSelf(operandTy);
+ Type resultETy = getElementTypeOrSelf(resultTy);
+
+ if (!llvm::isa<Float4E2M1FNType>(resultETy)) {
+ return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
+ }
+
+ Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+ Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
+ Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+ Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+ // Constants
+ Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
+ Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
+ Value c0x7e = createConst(op.getLoc(), i8Ty, 0x7e, rewriter);
+ Value c0x00000009 = createConst(op->getLoc(), i32Ty, 9, rewriter);
+ Value c0x00000017 = createConst(op->getLoc(), i32Ty, 23, rewriter);
+ Value c0x00200000 = createConst(op.getLoc(), i32Ty, 0x200000, rewriter);
+ Value c0x00400000 = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
+ Value c0x00600000 = createConst(op.getLoc(), i32Ty, 0x600000, rewriter);
+ Value c0x007fffff = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
+
+ Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
+
+ Value cF32MantissaWidth = c0x00000017; // 23
+ Value cF4MantissaWidth = c0x1; // 1
+ Value cF32SignExpWidth = c0x00000009; // 9
+ Value cF32MantissaMask = c0x007fffff;
+ Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
+ Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
+ Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
+
+ Value cSubnormalExp = c0x7e; // 126
+
+ // Regular case
+ Value biasAdjustment = c0x7e; // 126
+ Value cRoundUp = c0x00600000; // 110 0000...
+ Value cRoundDown = c0x00200000; // 010 0000...
+ Value biasAdjustedExp = b.create<arith::SubIOp>(exp8Bits, biasAdjustment);
+ Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedExp);
+ Value f4ExpRounded = b.create<arith::AddIOp>(f4Exp, c0x1);
+ // If we round up or down to even, set mantissa to 0
+ Value shouldRoundUp =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man23Bits, cRoundUp);
+ Value shouldRoundDown = b.create<arith::CmpIOp>(arith::CmpIPredicate::ule,
+ man23Bits, cRoundDown);
+ // dont need to worry about saturation this way
+ f4Exp = b.create<arith::SelectOp>(shouldRoundUp, f4ExpRounded, f4Exp);
+ Value f4BitsMan0 = b.create<arith::ShLIOp>(f4Exp, cF4MantissaWidth);
+ Value f4Bits = b.create<arith::AddIOp>(f4BitsMan0, c0x1);
+ f4Bits = b.create<arith::SelectOp>(shouldRoundUp, f4BitsMan0, f4Bits);
+ f4Bits = b.create<arith::SelectOp>(shouldRoundDown, f4BitsMan0, f4Bits);
+
+ // Bordering subnormal
+ Value cSubnormalRoundUp =
+ createConst(op.getLoc(), i32Ty, 0x4fffff, rewriter);
+ Value f4Edge = createConst(op.getLoc(), i4Ty, 0x1, rewriter);
+ Value f4EdgeRounded = createConst(op.getLoc(), i4Ty, 0x2, rewriter);
+ Value isEdgeRounded = b.create<arith::CmpIOp>(arith::CmpIPredicate::uge,
+ man23Bits, cSubnormalRoundUp);
+ f4EdgeRounded =
+ b.create<arith::SelectOp>(isEdgeRounded, f4EdgeRounded, f4Edge);
+ Value isEdge = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, exp8Bits,
+ cSubnormalExp);
+
+ // Subnormal
+ Value f4Zero = createConst(op.getLoc(), i4Ty, 0x0, rewriter);
+ Value isZero = b.create<arith::CmpIOp>(arith::CmpIPredicate::ule, exp8Bits,
+ cSubnormalExp);
+
+ // create constants to extract mantissa / exponent
+ Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
+ Value cF4SignAndExpWidth = createConst(op->getLoc(), i32Ty, 3, rewriter);
+ // Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
+
+ // create constants for NaNs
+ Value cF4NaN = createConst(op.getLoc(), i4Ty, 0xf, rewriter);
+ Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
+ Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+
+ Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+ Value f32Exponent = b.create<arith::ShLIOp>(exti, cF4MantissaWidth);
+ Value f32Mantissa = b.create<arith::ShRUIOp>(exti, cF4SignAndExpWidth);
+ Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
+
+ Value isNan =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
+ // select for NaNs
+ f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
+ Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+ if (resultETy.getIntOrFloatBitWidth() < 32) {
+ result = b.create<arith::TruncFOp>(resultTy, result);
+ } else if (resultETy.getIntOrFloatBitWidth() > 32) {
+ result = b.create<arith::ExtFOp>(resultTy, result);
+ }
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
/*
TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
>From abc108e9f40adfcb1ac5b30680626ea5f1a35d33 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Thu, 12 Jun 2025 16:26:50 +0000
Subject: [PATCH 2/4] intermediate commit
Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
.../Dialect/Arith/Transforms/ExpandOps.cpp | 79 ++++++++++++++++---
1 file changed, 66 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 87483ced8e5cf..0fb802b82fffb 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -34,6 +34,18 @@ static Value createConst(Location loc, Type type, int value,
return rewriter.create<arith::ConstantOp>(loc, attr);
}
+/// Create an float constant.
+static Value createFloatConst(Location loc, Type type, float value,
+ PatternRewriter &rewriter) {
+ auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
+ if (auto shapedTy = dyn_cast<ShapedType>(type)) {
+ return rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(shapedTy, attr));
+ }
+
+ return rewriter.create<arith::ConstantOp>(loc, attr);
+}
+
/// Creates shapedType using shape from cloneFrom and base type from cloneTo
static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
@@ -439,6 +451,13 @@ Note: x is sign bit
| x111 | 6.0
Conversion procedure:
+
+Step 1: Clamp to max f4 value
+
+Step 2: convert exponent, if signed int comparison <= 0, set 0
+
+Step 3: if mantissa[1:] greater than 1000000, add 1
+
Let M_0 = f32 mantissa, M_1 = f4 mantissa, Let E_0 = f32 exp, let E_1 = f4 exp
Create bias adjusted exponent, E_1 <- E_0 - 126
If E_0 <= 0111 1110
@@ -485,32 +504,66 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
// Constants
Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
- Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
Value c0x7e = createConst(op.getLoc(), i8Ty, 0x7e, rewriter);
+ Value c0x0000007e = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
+
Value c0x00000009 = createConst(op->getLoc(), i32Ty, 9, rewriter);
+ Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
Value c0x00000017 = createConst(op->getLoc(), i32Ty, 23, rewriter);
+ Value c0x0000001f = createConst(op->getLoc(), i32Ty, 31, rewriter);
Value c0x00200000 = createConst(op.getLoc(), i32Ty, 0x200000, rewriter);
Value c0x00400000 = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
Value c0x00600000 = createConst(op.getLoc(), i32Ty, 0x600000, rewriter);
+ Value c0x003fffff = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
Value c0x007fffff = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
-
- Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
-
Value cF32MantissaWidth = c0x00000017; // 23
- Value cF4MantissaWidth = c0x1; // 1
- Value cF32SignExpWidth = c0x00000009; // 9
+ Value cF4MantissaWidth = c0x1; // 1
+ Value cF32SignExpWidth = c0x00000009; // 9
+ Value cF32FirstBitMask = c0x00400000;
+ Value cF32Last22BitMask = c0x003fffff;
Value cF32MantissaMask = c0x007fffff;
+
+ // Step 1: Clamp to bounds.
+ Value cHigherBound = createFloatConst(op->getLoc(), f32Ty, 6.0, rewriter);
+ Value cLowerBound = createFloatConst(op->getLoc(), f32Ty, -6.0, rewriter);
+ Value clampHigh = b.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, operand, cHigherBound);
+ Value clampLow = b.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, operand, cLowerBound);
+ Value operandClamped = b.create<arith::SelectOp>(clampHigh, cHigherBound, operand);
+ operandClamped = b.create<arith::SelectOp>(clampLow, cLowerBound, operandClamped);
+ Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
+
+ // Step 2: Convert exponent by adjusting bias.
Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
- Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
- Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
+ Value biasAdjustment = c0x0000007e; // 126
+ Value biasAdjustedSignExp = b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
+ Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
+ f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
- Value cSubnormalExp = c0x7e; // 126
+ // Step 0: Special consideration for conversion to 0.5.
+ Value cSubnormalLowerBound = createFloatConst(op->getLoc(), f32Ty, 0.25, rewriter);
+ Value cSubnormalHigherBound = createFloatConst(op->getLoc(), f32Ty, 0.75, rewriter);
+ Value cLowerBound = createConst(op->getLoc(), f32Ty, -6.0, rewriter);
+ Value isSubnormal =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
+
+ // Step 3: Set mantissa to first bit.
+ Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
+ Value man1Bit = b.create<arith::ShRUIOp>(man23Bits, c0x00000016);
+ Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
+ Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
+
+ // Step 4: Round up if necessary.
+ Value cRound = c0x00200000; // 010 0000...
+ Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
+ Value shouldRound =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
+ Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1);
+ f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
// Regular case
- Value biasAdjustment = c0x7e; // 126
- Value cRoundUp = c0x00600000; // 110 0000...
- Value cRoundDown = c0x00200000; // 010 0000...
- Value biasAdjustedExp = b.create<arith::SubIOp>(exp8Bits, biasAdjustment);
+
+
+
Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedExp);
Value f4ExpRounded = b.create<arith::AddIOp>(f4Exp, c0x1);
// If we round up or down to even, set mantissa to 0
>From 63f2337c538f463f55104b4470aa88a07d364433 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Fri, 13 Jun 2025 13:06:00 -0700
Subject: [PATCH 3/4] Initial implementation of truncf
Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
.../Dialect/Arith/Transforms/ExpandOps.cpp | 265 +++++-------------
1 file changed, 66 insertions(+), 199 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 0fb802b82fffb..e35154e8ae32b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -334,57 +334,6 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
}
};
-struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(arith::ExtFOp op,
- PatternRewriter &rewriter) const final {
- ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- Value operand = op.getOperand();
- Type operandTy = operand.getType();
- Type resultTy = op.getType();
- Type operandETy = getElementTypeOrSelf(operandTy);
- Type resultETy = getElementTypeOrSelf(resultTy);
-
- if (!llvm::isa<Float4E2M1FNType>(operandETy)) {
- return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
- }
-
- Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
- Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
- Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
-
- Value bitcast = b.create<arith::BitcastOp>(i4Ty, operand);
-
- // create constants to extract mantissa / exponent
- Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
- Value cF4SignAndExpWidth = createConst(op->getLoc(), i32Ty, 3, rewriter);
- // Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
-
- // create constants for NaNs
- Value cF4NaN = createConst(op.getLoc(), i4Ty, 0xf, rewriter);
- Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
- Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
-
- Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
- Value f32Exponent = b.create<arith::ShLIOp>(exti, cF4MantissaWidth);
- Value f32Mantissa = b.create<arith::ShRUIOp>(exti, cF4SignAndExpWidth);
- Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
-
- Value isNan =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
- // select for NaNs
- f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
- Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
- if (resultETy.getIntOrFloatBitWidth() < 32) {
- result = b.create<arith::TruncFOp>(resultTy, result);
- } else if (resultETy.getIntOrFloatBitWidth() > 32) {
- result = b.create<arith::ExtFOp>(resultTy, result);
- }
- rewriter.replaceOp(op, result);
- return success();
- }
-};
-
struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::ExtFOp op,
@@ -428,60 +377,34 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
return success();
}
};
-/*
-Conversion from F32 to F4E2M1 according to the OCP Spec:
-www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
-
-The spec requiers us to perform Round to Nearest, Ties to Even.
-
-This means that after rounding, we should break ties by choosing the option
-which results in a mantissa of 0 in the least significant digit.
-
-Table of representable values in F4E2M1:
-
-Note: x is sign bit
-| Binary | Value ( + / - )
-| x000 | 0.0
-| x001 | 0.5
-| x010 | 1.0
-| x011 | 1.5
-| x100 | 2.0
-| x101 | 3.0
-| x110 | 4.0
-| x111 | 6.0
-
-Conversion procedure:
-
-Step 1: Clamp to max f4 value
-
-Step 2: convert exponent, if signed int comparison <= 0, set 0
-
-Step 3: if mantissa[1:] greater than 1000000, add 1
-
-Let M_0 = f32 mantissa, M_1 = f4 mantissa, Let E_0 = f32 exp, let E_1 = f4 exp
-Create bias adjusted exponent, E_1 <- E_0 - 126
-If E_0 <= 0111 1110
- M_1 <- 0, E_1 <- 00
- end
-if E_1 == 00 (special case for almost subnormal)
- if we must round up (M_0 >= 10000000000000000000000)
- M_1 <- 0
- E_1 <- 01
- else
- M_1 <- 1
- end
-Else if E_1 > 00
- roundToEven <- M_0 <= 01000000000000000000000 || M_0 >= 11000000000000000000000
- if roundToEven
- M_1 <- 0
- else
- M_1 <- 1
- If M_0 >= 11000000000000000000000
- increment E_1
- If E_1 > 11 (saturate if beyond range)
- M_1 <- 1, E_1 <- 11
-end
-*/
+
+/// Conversion from F32 to F4E2M1 according to the OCP Spec:
+/// www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+///
+/// The spec requires us to perform Round to Nearest, Ties to Even.
+///
+/// This means that after rounding, we should break ties by choosing the option
+/// which results in a mantissa of 0 in the least significant digit.
+///
+/// Table of representable values in F4E2M1:
+///
+/// Note: x is sign bit
+/// | Binary | Value ( + / - )
+/// | x000 | 0.0
+/// | x001 | 0.5
+/// | x010 | 1.0
+/// | x011 | 1.5
+/// | x100 | 2.0
+/// | x101 | 3.0
+/// | x110 | 4.0
+/// | x111 | 6.0
+///
+/// Conversion procedure:
+/// Step 1: Clamp to representable bounds.
+/// Step 2: Convert exponent by adjusting bias.
+/// Step 3: Set mantissa to first bit.
+/// Step 4: Special consideration for subnormal and zero exponent.
+/// Step 5: Round up if necessary: when mantissa[1:] >= 0b10...0, or when the value is subnormal.
struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::TruncFOp op,
@@ -504,122 +427,66 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
// Constants
Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
- Value c0x7e = createConst(op.getLoc(), i8Ty, 0x7e, rewriter);
- Value c0x0000007e = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
-
- Value c0x00000009 = createConst(op->getLoc(), i32Ty, 9, rewriter);
- Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
- Value c0x00000017 = createConst(op->getLoc(), i32Ty, 23, rewriter);
- Value c0x0000001f = createConst(op->getLoc(), i32Ty, 31, rewriter);
- Value c0x00200000 = createConst(op.getLoc(), i32Ty, 0x200000, rewriter);
- Value c0x00400000 = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
- Value c0x00600000 = createConst(op.getLoc(), i32Ty, 0x600000, rewriter);
- Value c0x003fffff = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
- Value c0x007fffff = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
- Value cF32MantissaWidth = c0x00000017; // 23
+ Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
Value cF4MantissaWidth = c0x1; // 1
- Value cF32SignExpWidth = c0x00000009; // 9
- Value cF32FirstBitMask = c0x00400000;
- Value cF32Last22BitMask = c0x003fffff;
- Value cF32MantissaMask = c0x007fffff;
-
+ Value cF32FirstBitMask = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
+ Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
+ Value c0x00 = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
+ Value c0xff = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
+ Value cF32MantissaMask = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
+ Value c0x00000000 = createConst(op.getLoc(), i32Ty, 0, rewriter);
+ Value cF32Last22BitMask = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);;
+
// Step 1: Clamp to bounds.
Value cHigherBound = createFloatConst(op->getLoc(), f32Ty, 6.0, rewriter);
Value cLowerBound = createFloatConst(op->getLoc(), f32Ty, -6.0, rewriter);
- Value clampHigh = b.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, operand, cHigherBound);
- Value clampLow = b.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, operand, cLowerBound);
- Value operandClamped = b.create<arith::SelectOp>(clampHigh, cHigherBound, operand);
- operandClamped = b.create<arith::SelectOp>(clampLow, cLowerBound, operandClamped);
+ Value operandClamped = b.create<arith::MinimumFOp>(clampLow, operand);
+ operandClamped = b.create<arith::MaximumFOp>(clampHigh, operandClamped);
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
// Step 2: Convert exponent by adjusting bias.
Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
- Value biasAdjustment = c0x0000007e; // 126
+ Value biasAdjustment = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
Value biasAdjustedSignExp = b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
- // Step 0: Special consideration for conversion to 0.5.
- Value cSubnormalLowerBound = createFloatConst(op->getLoc(), f32Ty, 0.25, rewriter);
- Value cSubnormalHigherBound = createFloatConst(op->getLoc(), f32Ty, 0.75, rewriter);
- Value cLowerBound = createConst(op->getLoc(), f32Ty, -6.0, rewriter);
- Value isSubnormal =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
-
// Step 3: Set mantissa to first bit.
- Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
- Value man1Bit = b.create<arith::ShRUIOp>(man23Bits, c0x00000016);
+ Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
+ man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
+
+ // Step 4: Special consideration for conversion to 0.5.
+ Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
+ Value isSubnormal =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
+ Value isNegOneExp =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
+ Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
+ Value isNonZeroMan =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
+ Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
+ Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
+ Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
+ Value isZeroExp =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
+
+ Value subResult = b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
+ subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
+ f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
- // Step 4: Round up if necessary.
- Value cRound = c0x00200000; // 010 0000...
+ // Step 5: Round up if necessary.
+ Value cRound = createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
Value shouldRound =
b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
+ shouldRound =
+ b.create<arith::OrIOp>(shouldRound, isSubnormal);
Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1);
f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
- // Regular case
-
-
-
- Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedExp);
- Value f4ExpRounded = b.create<arith::AddIOp>(f4Exp, c0x1);
- // If we round up or down to even, set mantissa to 0
- Value shouldRoundUp =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man23Bits, cRoundUp);
- Value shouldRoundDown = b.create<arith::CmpIOp>(arith::CmpIPredicate::ule,
- man23Bits, cRoundDown);
- // dont need to worry about saturation this way
- f4Exp = b.create<arith::SelectOp>(shouldRoundUp, f4ExpRounded, f4Exp);
- Value f4BitsMan0 = b.create<arith::ShLIOp>(f4Exp, cF4MantissaWidth);
- Value f4Bits = b.create<arith::AddIOp>(f4BitsMan0, c0x1);
- f4Bits = b.create<arith::SelectOp>(shouldRoundUp, f4BitsMan0, f4Bits);
- f4Bits = b.create<arith::SelectOp>(shouldRoundDown, f4BitsMan0, f4Bits);
-
- // Bordering subnormal
- Value cSubnormalRoundUp =
- createConst(op.getLoc(), i32Ty, 0x4fffff, rewriter);
- Value f4Edge = createConst(op.getLoc(), i4Ty, 0x1, rewriter);
- Value f4EdgeRounded = createConst(op.getLoc(), i4Ty, 0x2, rewriter);
- Value isEdgeRounded = b.create<arith::CmpIOp>(arith::CmpIPredicate::uge,
- man23Bits, cSubnormalRoundUp);
- f4EdgeRounded =
- b.create<arith::SelectOp>(isEdgeRounded, f4EdgeRounded, f4Edge);
- Value isEdge = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, exp8Bits,
- cSubnormalExp);
-
- // Subnormal
- Value f4Zero = createConst(op.getLoc(), i4Ty, 0x0, rewriter);
- Value isZero = b.create<arith::CmpIOp>(arith::CmpIPredicate::ule, exp8Bits,
- cSubnormalExp);
-
- // create constants to extract mantissa / exponent
- Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
- Value cF4SignAndExpWidth = createConst(op->getLoc(), i32Ty, 3, rewriter);
- // Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
-
- // create constants for NaNs
- Value cF4NaN = createConst(op.getLoc(), i4Ty, 0xf, rewriter);
- Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
- Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
-
- Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
- Value f32Exponent = b.create<arith::ShLIOp>(exti, cF4MantissaWidth);
- Value f32Mantissa = b.create<arith::ShRUIOp>(exti, cF4SignAndExpWidth);
- Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
-
- Value isNan =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
- // select for NaNs
- f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
- Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
- if (resultETy.getIntOrFloatBitWidth() < 32) {
- result = b.create<arith::TruncFOp>(resultTy, result);
- } else if (resultETy.getIntOrFloatBitWidth() > 32) {
- result = b.create<arith::ExtFOp>(resultTy, result);
- }
+ Value result = b.create<arith::BitcastOp>(resultTy, f4Bits);
rewriter.replaceOp(op, result);
return success();
}
>From 74be7a867fae63df9b1241db2662cb51080b3354 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Mon, 16 Jun 2025 10:06:34 -0700
Subject: [PATCH 4/4] PR Review round 1
Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
.../Dialect/Arith/Transforms/ExpandOps.cpp | 37 +++++++++----------
1 file changed, 18 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index e35154e8ae32b..f42889bb582c5 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -34,8 +34,8 @@ static Value createConst(Location loc, Type type, int value,
return rewriter.create<arith::ConstantOp>(loc, attr);
}
-/// Create an float constant.
-static Value createFloatConst(Location loc, Type type, float value,
+/// Create a float constant.
+static Value createFloatConst(Location loc, Type type, APFloat value,
PatternRewriter &rewriter) {
auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
@@ -416,8 +416,8 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Type operandETy = getElementTypeOrSelf(operandTy);
Type resultETy = getElementTypeOrSelf(resultTy);
- if (!llvm::isa<Float4E2M1FNType>(resultETy)) {
- return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
+ if (!isa<Float32Type>(operandETy) || !isa<Float4E2M1FNType>(resultETy)) {
+ return rewriter.notifyMatchFailure(op, "not a trunc of F32 to F4E2M1FN");
}
Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
@@ -425,17 +425,11 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
- // Constants
Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
- Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
- Value cF4MantissaWidth = c0x1; // 1
- Value cF32FirstBitMask = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
Value c0x00 = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
Value c0xff = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
- Value cF32MantissaMask = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
Value c0x00000000 = createConst(op.getLoc(), i32Ty, 0, rewriter);
- Value cF32Last22BitMask = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);;
// Step 1: Clamp to bounds.
Value cHigherBound = createFloatConst(op->getLoc(), f32Ty, 6.0, rewriter);
@@ -443,40 +437,45 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Value operandClamped = b.create<arith::MinimumFOp>(clampLow, operand);
operandClamped = b.create<arith::MaximumFOp>(clampHigh, operandClamped);
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
-
+
// Step 2: Convert exponent by adjusting bias.
- Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
Value biasAdjustment = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
+ Value cF4MantissaWidth = c0x1; // 1
+ Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
+ Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
Value biasAdjustedSignExp = b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
// Step 3: Set mantissa to first bit.
+ Value cF32FirstBitMask = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
// Step 4: Special consideration for conversion to 0.5.
+ Value cF32MantissaMask = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
Value isSubnormal =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
Value isNegOneExp =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
Value isNonZeroMan =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
- Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
- Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
Value isZeroExp =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
+ Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
+ Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
Value subResult = b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
-
+
// Step 5: Round up if necessary.
+ Value cF32Last22BitMask = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
Value cRound = createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
Value shouldRound =
More information about the Mlir-commits
mailing list