[Mlir-commits] [mlir] [mlir][arith][transforms] Adds Truncf f32 to f4e2m1 (PR #144157)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jun 19 00:13:26 PDT 2025
https://github.com/Muzammiluddin-Syed-ECE updated https://github.com/llvm/llvm-project/pull/144157
>From 6ef31a776b020ea5fb34ff6882aba7fe67ff421b Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Mon, 9 Jun 2025 21:06:41 +0000
Subject: [PATCH 1/8] initial implementation to fix
---
.../Dialect/Arith/Transforms/ExpandOps.cpp | 206 ++++++++++++++++++
1 file changed, 206 insertions(+)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 534aff9562b7a..87483ced8e5cf 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -322,6 +322,57 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
}
};
+struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::ExtFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ Value operand = op.getOperand();
+ Type operandTy = operand.getType();
+ Type resultTy = op.getType();
+ Type operandETy = getElementTypeOrSelf(operandTy);
+ Type resultETy = getElementTypeOrSelf(resultTy);
+
+ if (!llvm::isa<Float4E2M1FNType>(operandETy)) {
+ return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
+ }
+
+ Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+ Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+ Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+ Value bitcast = b.create<arith::BitcastOp>(i4Ty, operand);
+
+ // create constants to extract mantissa / exponent
+ Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
+ Value cF4SignAndExpWidth = createConst(op->getLoc(), i32Ty, 3, rewriter);
+ // Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
+
+ // create constants for NaNs
+ Value cF4NaN = createConst(op.getLoc(), i4Ty, 0xf, rewriter);
+ Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
+ Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+
+ Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+ Value f32Exponent = b.create<arith::ShLIOp>(exti, cF4MantissaWidth);
+ Value f32Mantissa = b.create<arith::ShRUIOp>(exti, cF4SignAndExpWidth);
+ Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
+
+ Value isNan =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
+ // select for NaNs
+ f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
+ Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+ if (resultETy.getIntOrFloatBitWidth() < 32) {
+ result = b.create<arith::TruncFOp>(resultTy, result);
+ } else if (resultETy.getIntOrFloatBitWidth() > 32) {
+ result = b.create<arith::ExtFOp>(resultTy, result);
+ }
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::ExtFOp op,
@@ -365,6 +416,161 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
return success();
}
};
+/*
+Conversion from F32 to F4E2M1 according to the OCP Spec:
+www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+
+The spec requiers us to perform Round to Nearest, Ties to Even.
+
+This means that after rounding, we should break ties by choosing the option
+which results in a mantissa of 0 in the least significant digit.
+
+Table of representable values in F4E2M1:
+
+Note: x is sign bit
+| Binary | Value ( + / - )
+| x000 | 0.0
+| x001 | 0.5
+| x010 | 1.0
+| x011 | 1.5
+| x100 | 2.0
+| x101 | 3.0
+| x110 | 4.0
+| x111 | 6.0
+
+Conversion procedure:
+Let M_0 = f32 mantissa, M_1 = f4 mantissa, Let E_0 = f32 exp, let E_1 = f4 exp
+Create bias adjusted exponent, E_1 <- E_0 - 126
+If E_0 <= 0111 1110
+ M_1 <- 0, E_1 <- 00
+ end
+if E_1 == 00 (special case for almost subnormal)
+ if we must round up (M_0 >= 10000000000000000000000)
+ M_1 <- 0
+ E_1 <- 01
+ else
+ M_1 <- 1
+ end
+Else if E_1 > 00
+ roundToEven <- M_0 <= 01000000000000000000000 || M_0 >= 11000000000000000000000
+ if roundToEven
+ M_1 <- 0
+ else
+ M_1 <- 1
+ If M_0 >= 11000000000000000000000
+ increment E_1
+ If E_1 > 11 (saturate if beyond range)
+ M_1 <- 1, E_1 <- 11
+end
+*/
+struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::TruncFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ Value operand = op.getOperand();
+ Type operandTy = operand.getType();
+ Type resultTy = op.getType();
+ Type operandETy = getElementTypeOrSelf(operandTy);
+ Type resultETy = getElementTypeOrSelf(resultTy);
+
+ if (!llvm::isa<Float4E2M1FNType>(resultETy)) {
+ return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
+ }
+
+ Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+ Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
+ Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+ Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+ // Constants
+ Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
+ Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
+ Value c0x7e = createConst(op.getLoc(), i8Ty, 0x7e, rewriter);
+ Value c0x00000009 = createConst(op->getLoc(), i32Ty, 9, rewriter);
+ Value c0x00000017 = createConst(op->getLoc(), i32Ty, 23, rewriter);
+ Value c0x00200000 = createConst(op.getLoc(), i32Ty, 0x200000, rewriter);
+ Value c0x00400000 = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
+ Value c0x00600000 = createConst(op.getLoc(), i32Ty, 0x600000, rewriter);
+ Value c0x007fffff = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
+
+ Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
+
+ Value cF32MantissaWidth = c0x00000017; // 23
+ Value cF4MantissaWidth = c0x1; // 1
+ Value cF32SignExpWidth = c0x00000009; // 9
+ Value cF32MantissaMask = c0x007fffff;
+ Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
+ Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
+ Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
+
+ Value cSubnormalExp = c0x7e; // 126
+
+ // Regular case
+ Value biasAdjustment = c0x7e; // 126
+ Value cRoundUp = c0x00600000; // 110 0000...
+ Value cRoundDown = c0x00200000; // 010 0000...
+ Value biasAdjustedExp = b.create<arith::SubIOp>(exp8Bits, biasAdjustment);
+ Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedExp);
+ Value f4ExpRounded = b.create<arith::AddIOp>(f4Exp, c0x1);
+ // If we round up or down to even, set mantissa to 0
+ Value shouldRoundUp =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man23Bits, cRoundUp);
+ Value shouldRoundDown = b.create<arith::CmpIOp>(arith::CmpIPredicate::ule,
+ man23Bits, cRoundDown);
+ // dont need to worry about saturation this way
+ f4Exp = b.create<arith::SelectOp>(shouldRoundUp, f4ExpRounded, f4Exp);
+ Value f4BitsMan0 = b.create<arith::ShLIOp>(f4Exp, cF4MantissaWidth);
+ Value f4Bits = b.create<arith::AddIOp>(f4BitsMan0, c0x1);
+ f4Bits = b.create<arith::SelectOp>(shouldRoundUp, f4BitsMan0, f4Bits);
+ f4Bits = b.create<arith::SelectOp>(shouldRoundDown, f4BitsMan0, f4Bits);
+
+ // Bordering subnormal
+ Value cSubnormalRoundUp =
+ createConst(op.getLoc(), i32Ty, 0x4fffff, rewriter);
+ Value f4Edge = createConst(op.getLoc(), i4Ty, 0x1, rewriter);
+ Value f4EdgeRounded = createConst(op.getLoc(), i4Ty, 0x2, rewriter);
+ Value isEdgeRounded = b.create<arith::CmpIOp>(arith::CmpIPredicate::uge,
+ man23Bits, cSubnormalRoundUp);
+ f4EdgeRounded =
+ b.create<arith::SelectOp>(isEdgeRounded, f4EdgeRounded, f4Edge);
+ Value isEdge = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, exp8Bits,
+ cSubnormalExp);
+
+ // Subnormal
+ Value f4Zero = createConst(op.getLoc(), i4Ty, 0x0, rewriter);
+ Value isZero = b.create<arith::CmpIOp>(arith::CmpIPredicate::ule, exp8Bits,
+ cSubnormalExp);
+
+ // create constants to extract mantissa / exponent
+ Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
+ Value cF4SignAndExpWidth = createConst(op->getLoc(), i32Ty, 3, rewriter);
+ // Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
+
+ // create constants for NaNs
+ Value cF4NaN = createConst(op.getLoc(), i4Ty, 0xf, rewriter);
+ Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
+ Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+
+ Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+ Value f32Exponent = b.create<arith::ShLIOp>(exti, cF4MantissaWidth);
+ Value f32Mantissa = b.create<arith::ShRUIOp>(exti, cF4SignAndExpWidth);
+ Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
+
+ Value isNan =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
+ // select for NaNs
+ f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
+ Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+ if (resultETy.getIntOrFloatBitWidth() < 32) {
+ result = b.create<arith::TruncFOp>(resultTy, result);
+ } else if (resultETy.getIntOrFloatBitWidth() > 32) {
+ result = b.create<arith::ExtFOp>(resultTy, result);
+ }
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
/*
TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
>From abc108e9f40adfcb1ac5b30680626ea5f1a35d33 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Thu, 12 Jun 2025 16:26:50 +0000
Subject: [PATCH 2/8] intermediate commit
Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
.../Dialect/Arith/Transforms/ExpandOps.cpp | 79 ++++++++++++++++---
1 file changed, 66 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 87483ced8e5cf..0fb802b82fffb 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -34,6 +34,18 @@ static Value createConst(Location loc, Type type, int value,
return rewriter.create<arith::ConstantOp>(loc, attr);
}
+/// Create an float constant.
+static Value createFloatConst(Location loc, Type type, float value,
+ PatternRewriter &rewriter) {
+ auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
+ if (auto shapedTy = dyn_cast<ShapedType>(type)) {
+ return rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(shapedTy, attr));
+ }
+
+ return rewriter.create<arith::ConstantOp>(loc, attr);
+}
+
/// Creates shapedType using shape from cloneFrom and base type from cloneTo
static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
@@ -439,6 +451,13 @@ Note: x is sign bit
| x111 | 6.0
Conversion procedure:
+
+Step 1: Clamp to max f4 value
+
+Step 2: convert exponent, if signed int comparison <= 0, set 0
+
+Step 3: if mantissa[1:] greater than 1000000, add 1
+
Let M_0 = f32 mantissa, M_1 = f4 mantissa, Let E_0 = f32 exp, let E_1 = f4 exp
Create bias adjusted exponent, E_1 <- E_0 - 126
If E_0 <= 0111 1110
@@ -485,32 +504,66 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
// Constants
Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
- Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
Value c0x7e = createConst(op.getLoc(), i8Ty, 0x7e, rewriter);
+ Value c0x0000007e = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
+
Value c0x00000009 = createConst(op->getLoc(), i32Ty, 9, rewriter);
+ Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
Value c0x00000017 = createConst(op->getLoc(), i32Ty, 23, rewriter);
+ Value c0x0000001f = createConst(op->getLoc(), i32Ty, 31, rewriter);
Value c0x00200000 = createConst(op.getLoc(), i32Ty, 0x200000, rewriter);
Value c0x00400000 = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
Value c0x00600000 = createConst(op.getLoc(), i32Ty, 0x600000, rewriter);
+ Value c0x003fffff = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
Value c0x007fffff = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
-
- Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
-
Value cF32MantissaWidth = c0x00000017; // 23
- Value cF4MantissaWidth = c0x1; // 1
- Value cF32SignExpWidth = c0x00000009; // 9
+ Value cF4MantissaWidth = c0x1; // 1
+ Value cF32SignExpWidth = c0x00000009; // 9
+ Value cF32FirstBitMask = c0x00400000;
+ Value cF32Last22BitMask = c0x003fffff;
Value cF32MantissaMask = c0x007fffff;
+
+ // Step 1: Clamp to bounds.
+ Value cHigherBound = createFloatConst(op->getLoc(), f32Ty, 6.0, rewriter);
+ Value cLowerBound = createFloatConst(op->getLoc(), f32Ty, -6.0, rewriter);
+ Value clampHigh = b.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, operand, cHigherBound);
+ Value clampLow = b.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, operand, cLowerBound);
+ Value operandClamped = b.create<arith::SelectOp>(clampHigh, cHigherBound, operand);
+ operandClamped = b.create<arith::SelectOp>(clampLow, cLowerBound, operandClamped);
+ Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
+
+ // Step 2: Convert exponent by adjusting bias.
Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
- Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
- Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
+ Value biasAdjustment = c0x0000007e; // 126
+ Value biasAdjustedSignExp = b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
+ Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
+ f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
- Value cSubnormalExp = c0x7e; // 126
+ // Step 0: Special consideration for conversion to 0.5.
+ Value cSubnormalLowerBound = createFloatConst(op->getLoc(), f32Ty, 0.25, rewriter);
+ Value cSubnormalHigherBound = createFloatConst(op->getLoc(), f32Ty, 0.75, rewriter);
+ Value cLowerBound = createConst(op->getLoc(), f32Ty, -6.0, rewriter);
+ Value isSubnormal =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
+
+ // Step 3: Set mantissa to first bit.
+ Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
+ Value man1Bit = b.create<arith::ShRUIOp>(man23Bits, c0x00000016);
+ Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
+ Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
+
+ // Step 4: Round up if necessary.
+ Value cRound = c0x00200000; // 010 0000...
+ Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
+ Value shouldRound =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
+ Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1);
+ f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
// Regular case
- Value biasAdjustment = c0x7e; // 126
- Value cRoundUp = c0x00600000; // 110 0000...
- Value cRoundDown = c0x00200000; // 010 0000...
- Value biasAdjustedExp = b.create<arith::SubIOp>(exp8Bits, biasAdjustment);
+
+
+
Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedExp);
Value f4ExpRounded = b.create<arith::AddIOp>(f4Exp, c0x1);
// If we round up or down to even, set mantissa to 0
>From 63f2337c538f463f55104b4470aa88a07d364433 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Fri, 13 Jun 2025 13:06:00 -0700
Subject: [PATCH 3/8] Initial implementation of truncf
Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
.../Dialect/Arith/Transforms/ExpandOps.cpp | 265 +++++-------------
1 file changed, 66 insertions(+), 199 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 0fb802b82fffb..e35154e8ae32b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -334,57 +334,6 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
}
};
-struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(arith::ExtFOp op,
- PatternRewriter &rewriter) const final {
- ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- Value operand = op.getOperand();
- Type operandTy = operand.getType();
- Type resultTy = op.getType();
- Type operandETy = getElementTypeOrSelf(operandTy);
- Type resultETy = getElementTypeOrSelf(resultTy);
-
- if (!llvm::isa<Float4E2M1FNType>(operandETy)) {
- return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
- }
-
- Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
- Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
- Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
-
- Value bitcast = b.create<arith::BitcastOp>(i4Ty, operand);
-
- // create constants to extract mantissa / exponent
- Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
- Value cF4SignAndExpWidth = createConst(op->getLoc(), i32Ty, 3, rewriter);
- // Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
-
- // create constants for NaNs
- Value cF4NaN = createConst(op.getLoc(), i4Ty, 0xf, rewriter);
- Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
- Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
-
- Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
- Value f32Exponent = b.create<arith::ShLIOp>(exti, cF4MantissaWidth);
- Value f32Mantissa = b.create<arith::ShRUIOp>(exti, cF4SignAndExpWidth);
- Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
-
- Value isNan =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
- // select for NaNs
- f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
- Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
- if (resultETy.getIntOrFloatBitWidth() < 32) {
- result = b.create<arith::TruncFOp>(resultTy, result);
- } else if (resultETy.getIntOrFloatBitWidth() > 32) {
- result = b.create<arith::ExtFOp>(resultTy, result);
- }
- rewriter.replaceOp(op, result);
- return success();
- }
-};
-
struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::ExtFOp op,
@@ -428,60 +377,34 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
return success();
}
};
-/*
-Conversion from F32 to F4E2M1 according to the OCP Spec:
-www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
-
-The spec requiers us to perform Round to Nearest, Ties to Even.
-
-This means that after rounding, we should break ties by choosing the option
-which results in a mantissa of 0 in the least significant digit.
-
-Table of representable values in F4E2M1:
-
-Note: x is sign bit
-| Binary | Value ( + / - )
-| x000 | 0.0
-| x001 | 0.5
-| x010 | 1.0
-| x011 | 1.5
-| x100 | 2.0
-| x101 | 3.0
-| x110 | 4.0
-| x111 | 6.0
-
-Conversion procedure:
-
-Step 1: Clamp to max f4 value
-
-Step 2: convert exponent, if signed int comparison <= 0, set 0
-
-Step 3: if mantissa[1:] greater than 1000000, add 1
-
-Let M_0 = f32 mantissa, M_1 = f4 mantissa, Let E_0 = f32 exp, let E_1 = f4 exp
-Create bias adjusted exponent, E_1 <- E_0 - 126
-If E_0 <= 0111 1110
- M_1 <- 0, E_1 <- 00
- end
-if E_1 == 00 (special case for almost subnormal)
- if we must round up (M_0 >= 10000000000000000000000)
- M_1 <- 0
- E_1 <- 01
- else
- M_1 <- 1
- end
-Else if E_1 > 00
- roundToEven <- M_0 <= 01000000000000000000000 || M_0 >= 11000000000000000000000
- if roundToEven
- M_1 <- 0
- else
- M_1 <- 1
- If M_0 >= 11000000000000000000000
- increment E_1
- If E_1 > 11 (saturate if beyond range)
- M_1 <- 1, E_1 <- 11
-end
-*/
+
+/// Conversion from F32 to F4E2M1 according to the OCP Spec:
+/// www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+///
+/// The spec requiers us to perform Round to Nearest, Ties to Even.
+///
+/// This means that after rounding, we should break ties by choosing the option
+/// which results in a mantissa of 0 in the least significant digit.
+///
+/// Table of representable values in F4E2M1:
+///
+/// Note: x is sign bit
+/// | Binary | Value ( + / - )
+/// | x000 | 0.0
+/// | x001 | 0.5
+/// | x010 | 1.0
+/// | x011 | 1.5
+/// | x100 | 2.0
+/// | x101 | 3.0
+/// | x110 | 4.0
+/// | x111 | 6.0
+///
+/// Conversion procedure:
+/// Step 1: Clamp to representable bounds.
+/// Step 2: Convert exponent by adjusting bias.
+/// Step 3: Set mantissa to first bit.
+/// Step 4: Special consideration for subnormal and zero exponent.
+/// Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or subnormal.
struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::TruncFOp op,
@@ -504,122 +427,66 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
// Constants
Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
- Value c0x7e = createConst(op.getLoc(), i8Ty, 0x7e, rewriter);
- Value c0x0000007e = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
-
- Value c0x00000009 = createConst(op->getLoc(), i32Ty, 9, rewriter);
- Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
- Value c0x00000017 = createConst(op->getLoc(), i32Ty, 23, rewriter);
- Value c0x0000001f = createConst(op->getLoc(), i32Ty, 31, rewriter);
- Value c0x00200000 = createConst(op.getLoc(), i32Ty, 0x200000, rewriter);
- Value c0x00400000 = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
- Value c0x00600000 = createConst(op.getLoc(), i32Ty, 0x600000, rewriter);
- Value c0x003fffff = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
- Value c0x007fffff = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
- Value cF32MantissaWidth = c0x00000017; // 23
+ Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
Value cF4MantissaWidth = c0x1; // 1
- Value cF32SignExpWidth = c0x00000009; // 9
- Value cF32FirstBitMask = c0x00400000;
- Value cF32Last22BitMask = c0x003fffff;
- Value cF32MantissaMask = c0x007fffff;
-
+ Value cF32FirstBitMask = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
+ Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
+ Value c0x00 = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
+ Value c0xff = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
+ Value cF32MantissaMask = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
+ Value c0x00000000 = createConst(op.getLoc(), i32Ty, 0, rewriter);
+ Value cF32Last22BitMask = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);;
+
// Step 1: Clamp to bounds.
Value cHigherBound = createFloatConst(op->getLoc(), f32Ty, 6.0, rewriter);
Value cLowerBound = createFloatConst(op->getLoc(), f32Ty, -6.0, rewriter);
- Value clampHigh = b.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, operand, cHigherBound);
- Value clampLow = b.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, operand, cLowerBound);
- Value operandClamped = b.create<arith::SelectOp>(clampHigh, cHigherBound, operand);
- operandClamped = b.create<arith::SelectOp>(clampLow, cLowerBound, operandClamped);
+ Value operandClamped = b.create<arith::MinimumFOp>(clampLow, operand);
+ operandClamped = b.create<arith::MaximumFOp>(clampHigh, operandClamped);
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
// Step 2: Convert exponent by adjusting bias.
Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
- Value biasAdjustment = c0x0000007e; // 126
+ Value biasAdjustment = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
Value biasAdjustedSignExp = b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
- // Step 0: Special consideration for conversion to 0.5.
- Value cSubnormalLowerBound = createFloatConst(op->getLoc(), f32Ty, 0.25, rewriter);
- Value cSubnormalHigherBound = createFloatConst(op->getLoc(), f32Ty, 0.75, rewriter);
- Value cLowerBound = createConst(op->getLoc(), f32Ty, -6.0, rewriter);
- Value isSubnormal =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
-
// Step 3: Set mantissa to first bit.
- Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
- Value man1Bit = b.create<arith::ShRUIOp>(man23Bits, c0x00000016);
+ Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
+ man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
+
+ // Step 4: Special consideration for conversion to 0.5.
+ Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
+ Value isSubnormal =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
+ Value isNegOneExp =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
+ Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
+ Value isNonZeroMan =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
+ Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
+ Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
+ Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
+ Value isZeroExp =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
+
+ Value subResult = b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
+ subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
+ f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
- // Step 4: Round up if necessary.
- Value cRound = c0x00200000; // 010 0000...
+ // Step 5: Round up if necessary.
+ Value cRound = createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
Value shouldRound =
b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
+ shouldRound =
+ b.create<arith::OrIOp>(shouldRound, isSubnormal);
Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1);
f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
- // Regular case
-
-
-
- Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedExp);
- Value f4ExpRounded = b.create<arith::AddIOp>(f4Exp, c0x1);
- // If we round up or down to even, set mantissa to 0
- Value shouldRoundUp =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man23Bits, cRoundUp);
- Value shouldRoundDown = b.create<arith::CmpIOp>(arith::CmpIPredicate::ule,
- man23Bits, cRoundDown);
- // dont need to worry about saturation this way
- f4Exp = b.create<arith::SelectOp>(shouldRoundUp, f4ExpRounded, f4Exp);
- Value f4BitsMan0 = b.create<arith::ShLIOp>(f4Exp, cF4MantissaWidth);
- Value f4Bits = b.create<arith::AddIOp>(f4BitsMan0, c0x1);
- f4Bits = b.create<arith::SelectOp>(shouldRoundUp, f4BitsMan0, f4Bits);
- f4Bits = b.create<arith::SelectOp>(shouldRoundDown, f4BitsMan0, f4Bits);
-
- // Bordering subnormal
- Value cSubnormalRoundUp =
- createConst(op.getLoc(), i32Ty, 0x4fffff, rewriter);
- Value f4Edge = createConst(op.getLoc(), i4Ty, 0x1, rewriter);
- Value f4EdgeRounded = createConst(op.getLoc(), i4Ty, 0x2, rewriter);
- Value isEdgeRounded = b.create<arith::CmpIOp>(arith::CmpIPredicate::uge,
- man23Bits, cSubnormalRoundUp);
- f4EdgeRounded =
- b.create<arith::SelectOp>(isEdgeRounded, f4EdgeRounded, f4Edge);
- Value isEdge = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, exp8Bits,
- cSubnormalExp);
-
- // Subnormal
- Value f4Zero = createConst(op.getLoc(), i4Ty, 0x0, rewriter);
- Value isZero = b.create<arith::CmpIOp>(arith::CmpIPredicate::ule, exp8Bits,
- cSubnormalExp);
-
- // create constants to extract mantissa / exponent
- Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
- Value cF4SignAndExpWidth = createConst(op->getLoc(), i32Ty, 3, rewriter);
- // Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
-
- // create constants for NaNs
- Value cF4NaN = createConst(op.getLoc(), i4Ty, 0xf, rewriter);
- Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
- Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
-
- Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
- Value f32Exponent = b.create<arith::ShLIOp>(exti, cF4MantissaWidth);
- Value f32Mantissa = b.create<arith::ShRUIOp>(exti, cF4SignAndExpWidth);
- Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
-
- Value isNan =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
- // select for NaNs
- f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
- Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
- if (resultETy.getIntOrFloatBitWidth() < 32) {
- result = b.create<arith::TruncFOp>(resultTy, result);
- } else if (resultETy.getIntOrFloatBitWidth() > 32) {
- result = b.create<arith::ExtFOp>(resultTy, result);
- }
+ Value result = b.create<arith::BitcastOp>(resultTy, f4Bits);
rewriter.replaceOp(op, result);
return success();
}
>From 30db57067002595e314a2ebe03b21ca20a87843f Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Mon, 16 Jun 2025 10:06:34 -0700
Subject: [PATCH 4/8] PR Review round 1
Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
.../Dialect/Arith/Transforms/ExpandOps.cpp | 51 ++++++++++---------
1 file changed, 26 insertions(+), 25 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index e35154e8ae32b..199f7c6d2a34d 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -34,9 +34,9 @@ static Value createConst(Location loc, Type type, int value,
return rewriter.create<arith::ConstantOp>(loc, attr);
}
-/// Create an float constant.
-static Value createFloatConst(Location loc, Type type, float value,
- PatternRewriter &rewriter) {
+/// Create a float constant.
+static Value createFloatConst(Location loc, Type type, APFloat value,
+ PatternRewriter &rewriter) {
auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
return rewriter.create<arith::ConstantOp>(
@@ -416,8 +416,8 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Type operandETy = getElementTypeOrSelf(operandTy);
Type resultETy = getElementTypeOrSelf(resultTy);
- if (!llvm::isa<Float4E2M1FNType>(resultETy)) {
- return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
+ if (!isa<Float32Type>(operandETy) || !isa<Float4E2M1FNType>(resultETy)) {
+ return rewriter.notifyMatchFailure(op, "not a trunc of F32 to F4E2M1FN");
}
Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
@@ -425,58 +425,59 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
- // Constants
Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
- Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
- Value cF4MantissaWidth = c0x1; // 1
- Value cF32FirstBitMask = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
Value c0x00 = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
Value c0xff = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
- Value cF32MantissaMask = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
Value c0x00000000 = createConst(op.getLoc(), i32Ty, 0, rewriter);
- Value cF32Last22BitMask = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);;
-
+
// Step 1: Clamp to bounds.
- Value cHigherBound = createFloatConst(op->getLoc(), f32Ty, 6.0, rewriter);
- Value cLowerBound = createFloatConst(op->getLoc(), f32Ty, -6.0, rewriter);
- Value operandClamped = b.create<arith::MinimumFOp>(clampLow, operand);
- operandClamped = b.create<arith::MaximumFOp>(clampHigh, operandClamped);
+ Value cHigherBound =
+ createFloatConst(op->getLoc(), f32Ty, APFloat(6.0f), rewriter);
+ Value cLowerBound =
+ createFloatConst(op->getLoc(), f32Ty, APFloat(-6.0f), rewriter);
+ Value operandClamped = b.create<arith::MinimumFOp>(cLowerBound, operand);
+ operandClamped = b.create<arith::MaximumFOp>(cHigherBound, operandClamped);
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
-
+
// Step 2: Convert exponent by adjusting bias.
- Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
Value biasAdjustment = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
+ Value cF4MantissaWidth = c0x1; // 1
+ Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
+ Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
Value biasAdjustedSignExp = b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
// Step 3: Set mantissa to first bit.
+ Value cF32FirstBitMask = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
// Step 4: Special consideration for conversion to 0.5.
+ Value cF32MantissaMask = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
Value isSubnormal =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
Value isNegOneExp =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
Value isNonZeroMan =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
- Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
- Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
Value isZeroExp =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
+ Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
+ Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
Value subResult = b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
-
+
// Step 5: Round up if necessary.
+ Value cF32Last22BitMask = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
Value cRound = createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
Value shouldRound =
>From c4e66f0d0ad1e18c510f3a5194b1a6949ce04fd6 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Mon, 16 Jun 2025 12:31:06 -0700
Subject: [PATCH 5/8] adding extf implementation
Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
.../Dialect/Arith/Transforms/ExpandOps.cpp | 117 ++++++++++++++----
1 file changed, 94 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 199f7c6d2a34d..889f1f17f0d82 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -334,6 +334,70 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
}
};
+struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::ExtFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ Value operand = op.getOperand();
+ Type operandTy = operand.getType();
+ Type resultTy = op.getType();
+ Type operandETy = getElementTypeOrSelf(operandTy);
+ Type resultETy = getElementTypeOrSelf(resultTy);
+
+ if (!llvm::isa<Float4E2M1FNType>(operandETy) ||
+ !llvm::isa<Float32Type>(resultETy)) {
+ return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN to F32");
+ }
+
+ Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+ Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+ Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+ Value bitcast = b.create<arith::BitcastOp>(i4Ty, operand);
+
+ Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
+ Value c0x0000001c = createConst(op->getLoc(), i32Ty, 28, rewriter);
+ Value c0x00000014 = createConst(op->getLoc(), i32Ty, 22, rewriter);
+ Value cZero =
+ createFloatConst(op->getLoc(), f32Ty, APFloat(0.0f), rewriter);
+ Value cHalf =
+ createFloatConst(op->getLoc(), f32Ty, APFloat(0.5f), rewriter);
+
+ Value mantissaBitmask = c0x1;
+ Value exponentBitmask = createConst(op.getLoc(), i4Ty, 0x6, rewriter);
+ Value signBitmask = createConst(op.getLoc(), i4Ty, 0x8, rewriter);
+
+ Value f4SignBit = b.create<arith::AndIOp>(bitcast, signBitmask);
+ Value f32Bits = b.create<arith::ExtUIOp>(i32Ty, f4SignBit);
+ f32Bits = b.create<arith::ShRUIOp>(f32Bits, c0x0000001c);
+
+ Value biasAdjustment = createConst(op.getLoc(), i32Ty, 126, rewriter);
+ Value f4ExpBits = b.create<arith::AndIOp>(bitcast, exponentBitmask);
+ f4ExpBits = b.create<arith::ShRUIOp>(f4ExpBits, c0x1);
+ Value f32ExpBits = b.create<arith::ExtUIOp>(i32Ty, f4ExpBits);
+ f32ExpBits = b.create<arith::AddIOp>(f32ExpBits, biasAdjustment);
+ f32ExpBits = b.create<arith::ShLIOp>(f32ExpBits, c0x00000014);
+ f32Bits = b.create<arith::AddIOp>(f32Bits, f32ExpBits);
+
+ Value f4ManBit = b.create<arith::AndIOp>(bitcast, mantissaBitmask);
+ Value f32ManBit = b.create<arith::ExtUIOp>(i32Ty, f4ManBit);
+ f32Bits = b.create<arith::AddIOp>(f32Bits, f32ManBit);
+
+ // Special consideration for subnormal exp (exp == 0).
+ Value isSubnormal = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
+ f32ExpBits, biasAdjustment);
+ Value isManSet =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f4ManBit, c0x1);
+ Value subnormalVal = b.create<arith::SelectOp>(isManSet, cHalf, cZero);
+ f32Bits = b.create<arith::SelectOp>(isSubnormal, subnormalVal, f32Bits);
+
+ Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::ExtFOp op,
@@ -389,7 +453,7 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
/// Table of representable values in F4E2M1:
///
/// Note: x is sign bit
-/// | Binary | Value ( + / - )
+/// | Binary | Value ( + / - )
/// | x000 | 0.0
/// | x001 | 0.5
/// | x010 | 1.0
@@ -399,12 +463,13 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
/// | x110 | 4.0
/// | x111 | 6.0
///
-/// Conversion procedure:
+/// Conversion procedure:
/// Step 1: Clamp to representable bounds.
/// Step 2: Convert exponent by adjusting bias.
/// Step 3: Set mantissa to first bit.
/// Step 4: Special consideration for subnormal and zero exponent.
-/// Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or subnormal.
+/// Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or
+/// subnormal.
struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::TruncFOp op,
@@ -442,48 +507,54 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
// Step 2: Convert exponent by adjusting bias.
Value biasAdjustment = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
- Value cF4MantissaWidth = c0x1; // 1
- Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
+ Value cF4MantissaWidth = c0x1; // 1
+ Value cF32MantissaWidth =
+ createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
- Value biasAdjustedSignExp = b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
+ Value biasAdjustedSignExp =
+ b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
-
+
// Step 3: Set mantissa to first bit.
- Value cF32FirstBitMask = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
+ Value cF32FirstBitMask =
+ createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
-
+
// Step 4: Special consideration for conversion to 0.5.
- Value cF32MantissaMask = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
+ Value cF32MantissaMask =
+ createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
Value isSubnormal =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
Value isNegOneExp =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
- Value isNonZeroMan =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
+ Value isNonZeroMan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt,
+ man23Bits, c0x00000000);
Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
- Value isZeroExp =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
-
+ Value isZeroExp =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
+
Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
- Value subResult = b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
+ Value subResult =
+ b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
-
+
// Step 5: Round up if necessary.
- Value cF32Last22BitMask = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
- Value cRound = createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
+ Value cF32Last22BitMask =
+ createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
+ Value cRound =
+ createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
Value shouldRound =
b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
- shouldRound =
- b.create<arith::OrIOp>(shouldRound, isSubnormal);
+ shouldRound = b.create<arith::OrIOp>(shouldRound, isSubnormal);
Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1);
f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
>From 1d8a9864161a66205e7c0a356efcbddb620e1af9 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Mon, 16 Jun 2025 22:28:47 -0700
Subject: [PATCH 6/8] add tests and fix various issues revealed by tests
Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
.../mlir/Dialect/Arith/Transforms/Passes.h | 3 +
.../mlir/Dialect/Arith/Transforms/Passes.td | 2 +
.../Dialect/Arith/Transforms/ExpandOps.cpp | 77 +++++--
mlir/test/Dialect/Arith/expand-ops-scale.mlir | 159 ++++++++++++++
mlir/test/Dialect/Arith/expand-ops.mlir | 195 ++++--------------
.../CPU/test-arith-expand-truncf-extf.mlir | 73 +++++++
6 files changed, 333 insertions(+), 176 deletions(-)
create mode 100644 mlir/test/Dialect/Arith/expand-ops-scale.mlir
create mode 100644 mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index e0a4567d6f406..b03cf2db78041 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -59,6 +59,9 @@ void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
/// Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
+/// Add patterns to expand Arith f4e2m1 patterns to lower level bitcasts/shifts.
+void populateExpandF4E2M1Patterns(RewritePatternSet &patterns);
+
/// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index e14b2aeee1c69..c7370b83fdb6c 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -19,6 +19,8 @@ def ArithExpandOpsPass : Pass<"arith-expand"> {
"Enable the BF16 expansion patterns">,
Option<"includeF8E8M0", "include-f8e8m0", "bool", /*default=*/"false",
"Enable the F8E8M0 expansion patterns">,
+ Option<"includeF4E2M1", "include-f4e2m1", "bool", /*default=*/"false",
+ "Enable the F4E2M1 expansion patterns">,
];
}
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 889f1f17f0d82..aef995143112a 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -345,9 +345,8 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
Type operandETy = getElementTypeOrSelf(operandTy);
Type resultETy = getElementTypeOrSelf(resultTy);
- if (!llvm::isa<Float4E2M1FNType>(operandETy) ||
- !llvm::isa<Float32Type>(resultETy)) {
- return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN to F32");
+ if (!isa<Float4E2M1FNType>(operandETy)) {
+ return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
}
Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
@@ -357,8 +356,9 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
Value bitcast = b.create<arith::BitcastOp>(i4Ty, operand);
Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
- Value c0x0000001c = createConst(op->getLoc(), i32Ty, 28, rewriter);
Value c0x00000014 = createConst(op->getLoc(), i32Ty, 22, rewriter);
+ Value c0x00000015 = createConst(op->getLoc(), i32Ty, 23, rewriter);
+ Value c0x0000001c = createConst(op->getLoc(), i32Ty, 28, rewriter);
Value cZero =
createFloatConst(op->getLoc(), f32Ty, APFloat(0.0f), rewriter);
Value cHalf =
@@ -370,29 +370,33 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
Value f4SignBit = b.create<arith::AndIOp>(bitcast, signBitmask);
Value f32Bits = b.create<arith::ExtUIOp>(i32Ty, f4SignBit);
- f32Bits = b.create<arith::ShRUIOp>(f32Bits, c0x0000001c);
+ f32Bits = b.create<arith::ShLIOp>(f32Bits, c0x0000001c);
Value biasAdjustment = createConst(op.getLoc(), i32Ty, 126, rewriter);
Value f4ExpBits = b.create<arith::AndIOp>(bitcast, exponentBitmask);
f4ExpBits = b.create<arith::ShRUIOp>(f4ExpBits, c0x1);
Value f32ExpBits = b.create<arith::ExtUIOp>(i32Ty, f4ExpBits);
f32ExpBits = b.create<arith::AddIOp>(f32ExpBits, biasAdjustment);
- f32ExpBits = b.create<arith::ShLIOp>(f32ExpBits, c0x00000014);
- f32Bits = b.create<arith::AddIOp>(f32Bits, f32ExpBits);
+ Value f32Exp = b.create<arith::ShLIOp>(f32ExpBits, c0x00000015);
+ f32Bits = b.create<arith::AddIOp>(f32Bits, f32Exp);
Value f4ManBit = b.create<arith::AndIOp>(bitcast, mantissaBitmask);
Value f32ManBit = b.create<arith::ExtUIOp>(i32Ty, f4ManBit);
+ f32ManBit = b.create<arith::ShLIOp>(f32ManBit, c0x00000014);
f32Bits = b.create<arith::AddIOp>(f32Bits, f32ManBit);
- // Special consideration for subnormal exp (exp == 0).
+ // Special consideration for subnormal exponent (exp == 00).
Value isSubnormal = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
f32ExpBits, biasAdjustment);
Value isManSet =
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f4ManBit, c0x1);
Value subnormalVal = b.create<arith::SelectOp>(isManSet, cHalf, cZero);
- f32Bits = b.create<arith::SelectOp>(isSubnormal, subnormalVal, f32Bits);
Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+ result = b.create<arith::SelectOp>(isSubnormal, subnormalVal, result);
+ if (!isa<Float32Type>(resultETy)) {
+ result = b.create<arith::TruncFOp>(resultETy, operand);
+ }
rewriter.replaceOp(op, result);
return success();
}
@@ -481,8 +485,11 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Type operandETy = getElementTypeOrSelf(operandTy);
Type resultETy = getElementTypeOrSelf(resultTy);
- if (!isa<Float32Type>(operandETy) || !isa<Float4E2M1FNType>(resultETy)) {
- return rewriter.notifyMatchFailure(op, "not a trunc of F32 to F4E2M1FN");
+ if (!isa<Float32Type>(operandETy)) {
+ operand = b.create<arith::ExtFOp>(b.getF32Type(), operand);
+ }
+ if (!isa<Float4E2M1FNType>(resultETy)) {
+ return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
}
Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
@@ -491,20 +498,28 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
+ Value c0x3 = createConst(op->getLoc(), i4Ty, 3, rewriter);
Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
Value c0x00 = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
Value c0xff = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
Value c0x00000000 = createConst(op.getLoc(), i32Ty, 0, rewriter);
- // Step 1: Clamp to bounds.
+ // Step 0: Clamp to bounds.
Value cHigherBound =
createFloatConst(op->getLoc(), f32Ty, APFloat(6.0f), rewriter);
Value cLowerBound =
createFloatConst(op->getLoc(), f32Ty, APFloat(-6.0f), rewriter);
- Value operandClamped = b.create<arith::MinimumFOp>(cLowerBound, operand);
- operandClamped = b.create<arith::MaximumFOp>(cHigherBound, operandClamped);
+ Value operandClamped = b.create<arith::MinimumFOp>(cHigherBound, operand);
+ operandClamped = b.create<arith::MaximumFOp>(cLowerBound, operandClamped);
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
+ // Step 1: Set sign bit.
+ Value cF32ExpManWidth =
+ createConst(op->getLoc(), i32Ty, 31, rewriter); // 23
+ Value f32Sign = b.create<arith::ShRUIOp>(f32Bits, cF32ExpManWidth);
+ Value f4Sign = b.create<arith::TruncIOp>(i4Ty, f32Sign);
+ Value f4Bits = b.create<arith::ShLIOp>(f4Sign, c0x3);
+
// Step 2: Convert exponent by adjusting bias.
Value biasAdjustment = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
Value cF4MantissaWidth = c0x1; // 1
@@ -513,8 +528,9 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
Value biasAdjustedSignExp =
b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
- Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
- f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
+ Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
+ f4Exp = b.create<arith::ShLIOp>(f4Exp, cF4MantissaWidth);
+ f4Bits = b.create<arith::AddIOp>(f4Bits, f4Exp);
// Step 3: Set mantissa to first bit.
Value cF32FirstBitMask =
@@ -522,7 +538,7 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
- Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
+ f4Bits = b.create<arith::AddIOp>(f4Bits, f4Man);
// Step 4: Special consideration for conversion to 0.5.
Value cF32MantissaMask =
@@ -538,7 +554,6 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
Value isZeroExp =
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
-
Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
Value subResult =
@@ -719,16 +734,24 @@ struct ArithExpandOpsPass
if (includeF8E8M0) {
arith::populateExpandF8E8M0Patterns(patterns);
}
+ if (includeF4E2M1) {
+ arith::populateExpandF4E2M1Patterns(patterns);
+ }
target.addDynamicallyLegalOp<arith::ExtFOp>(
[=](arith::ExtFOp op) {
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
Type outETy = getElementTypeOrSelf(op.getType());
bool legalTypes = true;
- if (includeBf16)
+ if (includeBf16) {
legalTypes &= !(inETy.isBF16() && outETy.isF32());
- if (includeF8E8M0)
+ }
+ if (includeF8E8M0) {
legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
+ }
+ if (includeF4E2M1) {
+ legalTypes &= !llvm::isa<Float4E2M1FNType>(inETy);
+ }
return legalTypes;
});
@@ -737,10 +760,15 @@ struct ArithExpandOpsPass
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
Type outETy = getElementTypeOrSelf(op.getType());
bool legalTypes = true;
- if (includeBf16)
+ if (includeBf16) {
legalTypes &= !(inETy.isF32() && outETy.isBF16());
- if (includeF8E8M0)
+ }
+ if (includeF8E8M0) {
legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
+ }
+ if (includeF4E2M1) {
+ legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
+ }
return legalTypes;
});
@@ -765,6 +793,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
patterns.getContext());
}
+void mlir::arith::populateExpandF4E2M1Patterns(RewritePatternSet &patterns) {
+ patterns.add<F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
+ patterns.getContext());
+}
+
void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
patterns.getContext());
diff --git a/mlir/test/Dialect/Arith/expand-ops-scale.mlir b/mlir/test/Dialect/Arith/expand-ops-scale.mlir
new file mode 100644
index 0000000000000..0b244e4eed784
--- /dev/null
+++ b/mlir/test/Dialect/Arith/expand-ops-scale.mlir
@@ -0,0 +1,159 @@
+// RUN: mlir-opt %s -arith-expand -split-input-file -verify-diagnostics | FileCheck %s
+
+func.func @scaling_truncf_f32_to_f4E2M1FN(%arg0 : f32, %arg1: f8E8M0FNU) -> f4E2M1FN {
+ %0 = arith.scaling_truncf %arg0, %arg1 : f32, f8E8M0FNU to f4E2M1FN
+ return %0 : f4E2M1FN
+}
+
+// CHECK-LABEL: @scaling_truncf_f32_to_f4E2M1FN
+// CHECK: %[[SCALEF32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
+// CHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF32]] : f32
+// CHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : f32 to f4E2M1FN
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_truncf_vector_f16_to_f6E3M2FN(%arg0 : vector<4xf16>, %arg1: vector<4xf8E8M0FNU>) -> vector<4xf6E3M2FN> {
+ %0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf8E8M0FNU> to vector<4xf6E3M2FN>
+ return %0 : vector<4xf6E3M2FN>
+}
+
+// CHECK-LABEL: @scaling_truncf_vector_f16_to_f6E3M2FN
+// CHECK: %[[SCALEF16:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16>
+// CHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF16]] : vector<4xf16>
+// CHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : vector<4xf16> to vector<4xf6E3M2FN>
+// CHECK: return %[[RESULT]] : vector<4xf6E3M2FN>
+
+// -----
+
+func.func @scaling_truncf_propagate_rounding_mode_fast_math(%arg0 : vector<4xf16>, %arg1: vector<4xf16>) -> vector<4xf6E3M2FN> {
+ %0 = arith.scaling_truncf %arg0, %arg1 to_nearest_even fastmath<fast> : vector<4xf16>, vector<4xf16> to vector<4xf6E3M2FN>
+ return %0 : vector<4xf6E3M2FN>
+}
+// CHECK-LABEL: @scaling_truncf_propagate_rounding_mode_fast_math
+// CHECK: %[[SCALEF8:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU>
+// CHECK: %[[SCALEINTY:.+]] = arith.extf %[[SCALEF8]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf16>
+// CHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEINTY]] fastmath<fast> : vector<4xf16>
+// CHECK: %[[TRUNCF:.+]] = arith.truncf [[_:%[a-zA-Z0-9_]+]] to_nearest_even fastmath<fast> : vector<4xf16> to vector<4xf6E3M2FN>
+// CHECK: return %[[TRUNCF]] : vector<4xf6E3M2FN>
+
+// -----
+
+func.func @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales(%arg0: f16, %arg1 : f16) -> f4E2M1FN {
+ %0 = arith.scaling_truncf %arg0, %arg1 : f16, f16 to f4E2M1FN
+ return %0 : f4E2M1FN
+}
+// CHECK-LABEL: @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales
+// CHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : f16 to f8E8M0FNU
+// CHECK: return
+
+// -----
+func.func @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales(%arg0: vector<4xf16>, %arg1 : vector<4xf16>) -> vector<4xf4E2M1FN> {
+ %0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf16> to vector<4xf4E2M1FN>
+ return %0 : vector<4xf4E2M1FN>
+}
+// CHECK-LABEL: @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales
+// CHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU>
+// CHECK: return
+
+// -----
+
+func.func @scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E8M0FNU) -> f32 {
+ %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E8M0FNU to f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: @scaling_extf_to_f32
+// CHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
+// CHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32
+// CHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_to_f32_using_f16_scales(%arg0: f4E2M1FN, %arg1 : f16) -> f32 {
+ %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f16 to f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: @scaling_extf_to_f32_using_f16_scales
+// CHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : f16 to f8E8M0FNU
+// CHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : f8E8M0FNU to f32
+// CHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32
+// CHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @invalid_scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E5M2FNUZ) -> f32 {
+ // expected-error@+1 {{failed to legalize operation 'arith.scaling_extf' that was explicitly marked illegal}}
+ %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E5M2FNUZ to f32
+ return %0 : f32
+}
+
+// -----
+
+func.func @scaling_extf_vector_to_f32(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf32> {
+ %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @scaling_extf_vector_to_f32
+// CHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf32>
+// CHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
+// CHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32>
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_vector_to_f16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf16> {
+ %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf16>
+ return %0 : vector<4xf16>
+}
+
+// CHECK-LABEL: @scaling_extf_vector_to_f16
+// CHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16>
+// CHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf16>
+// CHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf16>
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_vector_to_bf16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xbf16> {
+ %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xbf16>
+ return %0 : vector<4xbf16>
+}
+
+// CHECK-LABEL: @scaling_extf_vector_to_bf16
+// CHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xbf16>
+// CHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xbf16>
+// CHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xbf16>
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_vector_to_f32_using_f16_scales(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> {
+ %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales
+// CHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU>
+// CHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : vector<4xf8E8M0FNU> to vector<4xf32>
+// CHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
+// CHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32>
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_vector_to_f32_using_f16_scales_fastmath(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> {
+ %0 = arith.scaling_extf %arg0, %arg1 fastmath<fast> : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales_fastmath
+// CHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU>
+// CHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf32>
+// CHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 fastmath<fast> : vector<4xf4E2M1FN> to vector<4xf32>
+// CHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] fastmath<fast> : vector<4xf32>
+// CHECK: return %[[RESULT]]
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index db1349feaff3a..62e059ccbe8de 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -verify-diagnostics -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -arith-expand -split-input-file -verify-diagnostics | FileCheck %s --check-prefix=SCHECK
+// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true include-f4e2m1=true" -verify-diagnostics -split-input-file | FileCheck %s
// Test ceil divide with signed integer
// CHECK-LABEL: func @ceildivi
@@ -310,64 +309,6 @@ func.func @truncf_vector_bf16_to_f8E8M0FNU(%arg0 : vector<4xbf16>) -> vector<4xf
// -----
-func.func @scaling_truncf_f32_to_f4E2M1FN(%arg0 : f32, %arg1: f8E8M0FNU) -> f4E2M1FN {
- %0 = arith.scaling_truncf %arg0, %arg1 : f32, f8E8M0FNU to f4E2M1FN
- return %0 : f4E2M1FN
-}
-
-// SCHECK-LABEL: @scaling_truncf_f32_to_f4E2M1FN
-// SCHECK: %[[SCALEF32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
-// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF32]] : f32
-// SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : f32 to f4E2M1FN
-// SCHECK: return %[[RESULT]]
-
-// -----
-
-func.func @scaling_truncf_vector_f16_to_f6E3M2FN(%arg0 : vector<4xf16>, %arg1: vector<4xf8E8M0FNU>) -> vector<4xf6E3M2FN> {
- %0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf8E8M0FNU> to vector<4xf6E3M2FN>
- return %0 : vector<4xf6E3M2FN>
-}
-
-// SCHECK-LABEL: @scaling_truncf_vector_f16_to_f6E3M2FN
-// SCHECK: %[[SCALEF16:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16>
-// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF16]] : vector<4xf16>
-// SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : vector<4xf16> to vector<4xf6E3M2FN>
-// SCHECK: return %[[RESULT]] : vector<4xf6E3M2FN>
-
-// -----
-
-func.func @scaling_truncf_propagate_rounding_mode_fast_math(%arg0 : vector<4xf16>, %arg1: vector<4xf16>) -> vector<4xf6E3M2FN> {
- %0 = arith.scaling_truncf %arg0, %arg1 to_nearest_even fastmath<fast> : vector<4xf16>, vector<4xf16> to vector<4xf6E3M2FN>
- return %0 : vector<4xf6E3M2FN>
-}
-// SCHECK-LABEL: @scaling_truncf_propagate_rounding_mode_fast_math
-// SCHECK: %[[SCALEF8:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU>
-// SCHECK: %[[SCALEINTY:.+]] = arith.extf %[[SCALEF8]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf16>
-// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEINTY]] fastmath<fast> : vector<4xf16>
-// SCHECK: %[[TRUNCF:.+]] = arith.truncf [[_:%[a-zA-Z0-9_]+]] to_nearest_even fastmath<fast> : vector<4xf16> to vector<4xf6E3M2FN>
-// SCHECK: return %[[TRUNCF]] : vector<4xf6E3M2FN>
-
-// -----
-
-func.func @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales(%arg0: f16, %arg1 : f16) -> f4E2M1FN {
- %0 = arith.scaling_truncf %arg0, %arg1 : f16, f16 to f4E2M1FN
- return %0 : f4E2M1FN
-}
-// SCHECK-LABEL: @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales
-// SCHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : f16 to f8E8M0FN
-// SCHECK: return
-
-// -----
-func.func @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales(%arg0: vector<4xf16>, %arg1 : vector<4xf16>) -> vector<4xf4E2M1FN> {
- %0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf16> to vector<4xf4E2M1FN>
- return %0 : vector<4xf4E2M1FN>
-}
-// SCHECK-LABEL: @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales
-// SCHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU>
-// SCHECK: return
-
-// -----
-
func.func @invalid_scaling_truncf_to_f4E2M1FN(%arg0: f16, %arg1 : f8E5M2FNUZ) -> f4E2M1FN {
// expected-error at +1 {{failed to legalize operation 'arith.scaling_truncf' that was explicitly marked illegal}}
%0 = arith.scaling_truncf %arg0, %arg1 : f16, f8E5M2FNUZ to f4E2M1FN
@@ -446,33 +387,6 @@ func.func @extf_vector_f8E8M0FNU_to_bf16(%arg0 : vector<4xf8E8M0FNU>) -> vector<
// -----
-func.func @scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E8M0FNU) -> f32 {
- %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E8M0FNU to f32
- return %0 : f32
-}
-
-// SCHECK-LABEL: @scaling_extf_to_f32
-// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
-// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32
-// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32
-// SCHECK: return %[[RESULT]]
-
-// -----
-
-func.func @scaling_extf_to_f32_using_f16_scales(%arg0: f4E2M1FN, %arg1 : f16) -> f32 {
- %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f16 to f32
- return %0 : f32
-}
-
-// SCHECK-LABEL: @scaling_extf_to_f32_using_f16_scales
-// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : f16 to f8E8M0FNU
-// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : f8E8M0FNU to f32
-// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32
-// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32
-// SCHECK: return %[[RESULT]]
-
-// -----
-
func.func @invalid_scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E5M2FNUZ) -> f32 {
// expected-error at +1 {{failed to legalize operation 'arith.scaling_extf' that was explicitly marked illegal}}
%0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E5M2FNUZ to f32
@@ -481,73 +395,6 @@ func.func @invalid_scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E5M2FNUZ) -> f
// -----
-func.func @scaling_extf_vector_to_f32(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf32> {
- %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf32>
- return %0 : vector<4xf32>
-}
-
-// SCHECK-LABEL: @scaling_extf_vector_to_f32
-// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf32>
-// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
-// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32>
-// SCHECK: return %[[RESULT]]
-
-// -----
-
-func.func @scaling_extf_vector_to_f16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf16> {
- %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf16>
- return %0 : vector<4xf16>
-}
-
-// SCHECK-LABEL: @scaling_extf_vector_to_f16
-// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16>
-// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf16>
-// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf16>
-// SCHECK: return %[[RESULT]]
-
-// -----
-
-func.func @scaling_extf_vector_to_bf16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xbf16> {
- %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xbf16>
- return %0 : vector<4xbf16>
-}
-
-// SCHECK-LABEL: @scaling_extf_vector_to_bf16
-// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xbf16>
-// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xbf16>
-// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xbf16>
-// SCHECK: return %[[RESULT]]
-
-// -----
-
-func.func @scaling_extf_vector_to_f32_using_f16_scales(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> {
- %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32>
- return %0 : vector<4xf32>
-}
-
-// SCHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales
-// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU>
-// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : vector<4xf8E8M0FNU> to vector<4xf32>
-// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
-// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32>
-// SCHECK: return %[[RESULT]]
-
-// -----
-
-func.func @scaling_extf_vector_to_f32_using_f16_scales_fastmath(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> {
- %0 = arith.scaling_extf %arg0, %arg1 fastmath<fast> : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32>
- return %0 : vector<4xf32>
-}
-
-// SCHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales_fastmath
-// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU>
-// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf32>
-// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 fastmath<fast> : vector<4xf4E2M1FN> to vector<4xf32>
-// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] fastmath<fast> : vector<4xf32>
-// SCHECK: return %[[RESULT]]
-
-// -----
-
func.func @maxsi(%a: i32, %b: i32) -> i32 {
%result = arith.maxsi %a, %b : i32
return %result : i32
@@ -593,3 +440,43 @@ func.func @minui(%a: i32, %b: i32) -> i32 {
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: return %[[RESULT]] : i32
+
+// -----
+
+func.func @truncf_f32_to_f4E2M1FN(%src : f32) -> f4E2M1FN {
+  %narrowed = arith.truncf %src : f32 to f4E2M1FN
+  return %narrowed : f4E2M1FN
+}
+
+// CHECK-LABEL: @truncf_f32_to_f4E2M1FN
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @truncf_vector_f32_to_f4E2M1FN(%src : vector<4xf32>) -> vector<4xf4E2M1FN> {
+  %narrowed = arith.truncf %src : vector<4xf32> to vector<4xf4E2M1FN>
+  return %narrowed : vector<4xf4E2M1FN>
+}
+
+// CHECK-LABEL: @truncf_vector_f32_to_f4E2M1FN
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @extf_f4E2M1FN_to_f32(%src : f4E2M1FN) -> f32 {
+  %widened = arith.extf %src : f4E2M1FN to f32
+  return %widened : f32
+}
+
+// CHECK-LABEL: @extf_f4E2M1FN_to_f32
+// CHECK-NOT: arith.extf
+
+// -----
+
+func.func @extf_vector_f4E2M1FN_to_f32(%src : vector<4xf4E2M1FN>) -> vector<4xf32> {
+  %widened = arith.extf %src : vector<4xf4E2M1FN> to vector<4xf32>
+  return %widened : vector<4xf32>
+}
+
+// CHECK-LABEL: @extf_vector_f4E2M1FN_to_f32
+// CHECK-NOT: arith.extf
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
new file mode 100644
index 0000000000000..6e76968c70e5f
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
@@ -0,0 +1,73 @@
+// Check various edge cases for truncf/extf ops involving f32 and f4e2m1 types.
+
+// RUN: mlir-opt %s --convert-vector-to-llvm \
+// RUN: --convert-func-to-llvm \
+// RUN: --arith-expand="include-f4e2m1=true" \
+// RUN: --convert-arith-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e entry --entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s --match-full-lines
+
+func.func @check_extf(%in : f4E2M1FN) -> () {
+  %widened = arith.extf %in : f4E2M1FN to f32
+  vector.print %widened : f32
+  return
+}
+
+// Prints the raw 4-bit F4E2M1 encoding of the truncated value; the encoding is
+// defined in https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+func.func @check_truncf(%in : f32) -> () {
+  %narrowed = arith.truncf %in : f32 to f4E2M1FN
+  %rawBits = arith.bitcast %narrowed : f4E2M1FN to i4
+  %printable = arith.extui %rawBits : i4 to i64
+  vector.print %printable : i64
+  return
+}
+
+func.func @entry() {
+  %zero = arith.constant 0.0 : f32
+  %half = arith.constant 0.5 : f32
+  %one = arith.constant 1.0 : f32
+  %max = arith.constant 6.0 : f32
+  %min = arith.constant -6.0 : f32
+  %lowerThanMin = arith.constant -1000000.0 : f32
+  %higherThanMax = arith.constant 1000000.0 : f32
+  %mustRound = arith.constant -3.14 : f32
+  %nan = arith.constant 0x7fc00000 : f32 // quiet NaN bit pattern
+
+  // CHECK: 0
+  func.call @check_truncf(%zero) : (f32) -> ()
+  // CHECK: 1
+  func.call @check_truncf(%half) : (f32) -> ()
+  // CHECK: 2
+  func.call @check_truncf(%one) : (f32) -> ()
+  // CHECK: 7
+  func.call @check_truncf(%max) : (f32) -> ()
+  // CHECK: 15
+  func.call @check_truncf(%min) : (f32) -> ()
+  // CHECK: 7
+  func.call @check_truncf(%higherThanMax) : (f32) -> ()
+  // CHECK: 15
+  func.call @check_truncf(%lowerThanMin) : (f32) -> ()
+  // CHECK: 13
+  func.call @check_truncf(%mustRound) : (f32) -> ()
+  // CHECK: 7
+  func.call @check_truncf(%nan) : (f32) -> () // minnumf clamps NaN to 6.0
+
+  // CHECK: 0
+  %zeroF4 = arith.truncf %zero : f32 to f4E2M1FN
+  func.call @check_extf(%zeroF4) : (f4E2M1FN) -> ()
+  // CHECK: 0.5
+  %halfF4 = arith.truncf %half : f32 to f4E2M1FN
+  func.call @check_extf(%halfF4) : (f4E2M1FN) -> ()
+  // CHECK: 6
+  %higherThanMaxF4 = arith.truncf %higherThanMax : f32 to f4E2M1FN
+  func.call @check_extf(%higherThanMaxF4) : (f4E2M1FN) -> ()
+  // CHECK: -6
+  %lowerThanMinF4 = arith.truncf %lowerThanMin : f32 to f4E2M1FN
+  func.call @check_extf(%lowerThanMinF4) : (f4E2M1FN) -> ()
+  // CHECK: -3
+  %mustRoundF4 = arith.truncf %mustRound : f32 to f4E2M1FN
+  func.call @check_extf(%mustRoundF4) : (f4E2M1FN) -> ()
+  return
+}
>From 19394a89cfdbe082bfbf2a45555fa44c5a348fb3 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Wed, 18 Jun 2025 19:28:02 +0000
Subject: [PATCH 7/8] Adding lookup implementation for arith.extf + formatting
fixes
Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
.../Dialect/Arith/Transforms/ExpandOps.cpp | 126 +++++++++++++-----
.../CPU/test-arith-expand-truncf-extf.mlir | 6 +-
2 files changed, 93 insertions(+), 39 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index aef995143112a..aec2001d64443 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -13,6 +13,7 @@
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SmallVectorExtras.h"
namespace mlir {
namespace arith {
@@ -240,9 +241,8 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
Type operandETy = getElementTypeOrSelf(operandTy);
Type resultETy = getElementTypeOrSelf(resultTy);
- if (!operandETy.isBF16() || !resultETy.isF32()) {
+ if (!operandETy.isBF16() || !resultETy.isF32())
return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
- }
Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
@@ -270,9 +270,8 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Type operandETy = getElementTypeOrSelf(operandTy);
Type resultETy = getElementTypeOrSelf(resultTy);
- if (!operandETy.isF32() || !resultETy.isBF16()) {
+ if (!operandETy.isF32() || !resultETy.isBF16())
return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
- }
if (op.getRoundingmodeAttr()) {
return rewriter.notifyMatchFailure(
@@ -336,6 +335,8 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern::OpRewritePattern;
+ F4E2M1ExtFOpConverter(MLIRContext *context, PatternBenefit benefit = 1)
+ : OpRewritePattern<arith::ExtFOp>(context, benefit) {}
LogicalResult matchAndRewrite(arith::ExtFOp op,
PatternRewriter &rewriter) const final {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
@@ -402,6 +403,71 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
}
};
+struct ScalarF4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ ScalarF4E2M1ExtFOpConverter(MLIRContext *context, PatternBenefit benefit = 2)
+ : OpRewritePattern<arith::ExtFOp>(context, benefit) {}
+ LogicalResult matchAndRewrite(arith::ExtFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ Value operand = op.getOperand();
+ Type operandTy = operand.getType();
+ Type resultTy = op.getType();
+ Type operandETy = getElementTypeOrSelf(operandTy);
+ Type resultETy = getElementTypeOrSelf(resultTy);
+
+ if (isa<ShapedType>(operandTy))
+ return failure();
+
+ if (!isa<Float4E2M1FNType>(operandETy))
+ return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
+
+ SmallVector<int> values = {
+ 0x00000000, // 0.0
+ 0x3f000000, // 0.5
+ 0x3f800000, // 1.0
+ 0x3fc00000, // 1.5
+ 0x40000000, // 2.0
+ 0x40400000, // 3.0
+ 0x40800000, // 4.0
+ 0x40c00000 // 6.0
+ };
+ // auto type = RankedTensorType::get({8}, b.getI32Type());
+ VectorType type = VectorType::get({8}, b.getI32Type());
+ SmallVector<Attribute> lookupTableAttr = llvm::map_to_vector(
+ values, [&](int v) -> Attribute { return b.getI32IntegerAttr(v); });
+ Value lookupTable = b.create<arith::ConstantOp>(
+ DenseIntElementsAttr::get(type, lookupTableAttr));
+
+ Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+ Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+ Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+ Type i64Ty = cloneToShapedType(operandTy, b.getI64Type());
+
+ Value i4Bits = b.create<arith::BitcastOp>(i4Ty, operand);
+
+ Value expManBitmask = createConst(op.getLoc(), i4Ty, 0x7, rewriter);
+ Value indexI4 = b.create<arith::AndIOp>(i4Bits, expManBitmask);
+ Value indexI64 = b.create<arith::ExtUIOp>(i64Ty, indexI4);
+ Value index = b.create<arith::IndexCastOp>(b.getIndexType(), indexI64);
+
+ Value c0x0000001c = createConst(op->getLoc(), i32Ty, 28, rewriter);
+ Value signBitmask = createConst(op.getLoc(), i4Ty, 0x8, rewriter);
+ Value signBitI4 = b.create<arith::AndIOp>(i4Bits, signBitmask);
+ Value signBitI32 = b.create<arith::ExtUIOp>(i32Ty, signBitI4);
+ signBitI32 = b.create<arith::ShLIOp>(signBitI32, c0x0000001c);
+
+ Value unsignedBits = b.create<vector::ExtractOp>(lookupTable, index);
+ Value f32Bits = b.create<arith::OrIOp>(signBitI32, unsignedBits);
+ Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+ if (!isa<Float32Type>(resultETy))
+ result = b.create<arith::TruncFOp>(resultETy, operand);
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::ExtFOp op,
@@ -413,9 +479,8 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
Type operandETy = getElementTypeOrSelf(operandTy);
Type resultETy = getElementTypeOrSelf(resultTy);
- if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
+ if (!llvm::isa<Float8E8M0FNUType>(operandETy))
return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
- }
Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
@@ -485,12 +550,10 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Type operandETy = getElementTypeOrSelf(operandTy);
Type resultETy = getElementTypeOrSelf(resultTy);
- if (!isa<Float32Type>(operandETy)) {
+ if (!isa<Float32Type>(operandETy))
operand = b.create<arith::ExtFOp>(b.getF32Type(), operand);
- }
- if (!isa<Float4E2M1FNType>(resultETy)) {
+ if (!isa<Float4E2M1FNType>(resultETy))
return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
- }
Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
@@ -509,8 +572,8 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
createFloatConst(op->getLoc(), f32Ty, APFloat(6.0f), rewriter);
Value cLowerBound =
createFloatConst(op->getLoc(), f32Ty, APFloat(-6.0f), rewriter);
- Value operandClamped = b.create<arith::MinimumFOp>(cHigherBound, operand);
- operandClamped = b.create<arith::MaximumFOp>(cLowerBound, operandClamped);
+ Value operandClamped = b.create<arith::MinNumFOp>(cHigherBound, operand);
+ operandClamped = b.create<arith::MaxNumFOp>(cLowerBound, operandClamped);
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
// Step 1: Set sign bit.
@@ -594,14 +657,12 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Type operandETy = getElementTypeOrSelf(operandTy);
Type resultTy = op.getType();
Type resultETy = getElementTypeOrSelf(resultTy);
- if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
+ if (!llvm::isa<Float8E8M0FNUType>(resultETy))
return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
- }
- if (op.getRoundingmodeAttr()) {
+ if (op.getRoundingmodeAttr())
return rewriter.notifyMatchFailure(
op, "only applicable to default rounding mode.");
- }
Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
@@ -711,6 +772,8 @@ struct ArithExpandOpsPass
arith::populateArithExpandOpsPatterns(patterns);
target.addLegalDialect<arith::ArithDialect>();
+ target.addLegalDialect<vector::VectorDialect>();
+
// clang-format off
target.addIllegalOp<
arith::CeilDivSIOp,
@@ -728,30 +791,24 @@ struct ArithExpandOpsPass
arith::ScalingTruncFOp
>();
- if (includeBf16) {
+ if (includeBf16)
arith::populateExpandBFloat16Patterns(patterns);
- }
- if (includeF8E8M0) {
+ if (includeF8E8M0)
arith::populateExpandF8E8M0Patterns(patterns);
- }
- if (includeF4E2M1) {
+ if (includeF4E2M1)
arith::populateExpandF4E2M1Patterns(patterns);
- }
target.addDynamicallyLegalOp<arith::ExtFOp>(
[=](arith::ExtFOp op) {
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
Type outETy = getElementTypeOrSelf(op.getType());
bool legalTypes = true;
- if (includeBf16) {
+ if (includeBf16)
legalTypes &= !(inETy.isBF16() && outETy.isF32());
- }
- if (includeF8E8M0) {
+ if (includeF8E8M0)
legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
- }
- if (includeF4E2M1) {
+ if (includeF4E2M1)
legalTypes &= !llvm::isa<Float4E2M1FNType>(inETy);
- }
return legalTypes;
});
@@ -760,15 +817,12 @@ struct ArithExpandOpsPass
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
Type outETy = getElementTypeOrSelf(op.getType());
bool legalTypes = true;
- if (includeBf16) {
+ if (includeBf16)
legalTypes &= !(inETy.isF32() && outETy.isBF16());
- }
- if (includeF8E8M0) {
+ if (includeF8E8M0)
legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
- }
- if (includeF4E2M1) {
+ if (includeF4E2M1)
legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
- }
return legalTypes;
});
@@ -794,8 +848,8 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
}
void mlir::arith::populateExpandF4E2M1Patterns(RewritePatternSet &patterns) {
- patterns.add<F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
- patterns.getContext());
+ patterns.add<F4E2M1ExtFOpConverter, ScalarF4E2M1ExtFOpConverter,
+ F4E2M1TruncFOpConverter>(patterns.getContext());
}
void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
index 6e76968c70e5f..9c310d80d4c2d 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
@@ -1,9 +1,9 @@
// Check various edge cases for truncf/extf ops involving f32 and f4e2m1 types.
-// RUN: mlir-opt %s --convert-vector-to-llvm \
-// RUN: --convert-func-to-llvm \
+// RUN: mlir-opt %s --convert-func-to-llvm \
// RUN: --arith-expand="include-f4e2m1=true" \
-// RUN: --convert-arith-to-llvm -reconcile-unrealized-casts | \
+// RUN: --convert-arith-to-llvm --convert-vector-to-llvm \
+// RUN: --reconcile-unrealized-casts | \
// RUN: mlir-runner -e entry --entry-point-result=void \
// RUN: --shared-libs=%mlir_c_runner_utils | \
// RUN: FileCheck %s --match-full-lines
>From f554b7c22612190d305964ff416f8a277247359a Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Thu, 19 Jun 2025 05:03:27 +0000
Subject: [PATCH 8/8] improving extf implementation
Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
.../Dialect/Arith/Transforms/ExpandOps.cpp | 250 +++++++-----------
1 file changed, 102 insertions(+), 148 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index aec2001d64443..473980b44b66a 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -11,9 +11,11 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVectorExtras.h"
+#include <cstdint>
namespace mlir {
namespace arith {
@@ -333,133 +335,102 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
}
};
+/// In this implementation of extf we take advantage of some key patterns we
+/// notice between the binary representation of an F4E2M1 value and its
+/// corresponding value in fp32.
+///
+/// Note: x is the sign bit; bits are numbered from 1 (LSB).
+/// | Binary | F4E2M1 | fp32 bits[32:23]
+/// | x000   | 0.0    | x000 0000 00
+/// | x001   | 0.5    | x011 1111 00
+/// | x010   | 1.0    | x011 1111 10
+/// | x011   | 1.5    | x011 1111 11
+/// | x100   | 2.0    | x100 0000 00
+/// | x101   | 3.0    | x100 0000 01
+/// | x110   | 4.0    | x100 0000 10
+/// | x111   | 6.0    | x100 0000 11
+///
+/// 1) There are only two versions of bits [25:31] in the fp32 result
+///    F4E2M1 bit[3] (the high exponent bit) decides whether:
+///    - FP32 bits[25:31] = 100 0000 (bit[3] set)
+///    - FP32 bits[25:31] = 011 1111 (bit[3] clear)
+///    Exception is zero (magnitude bits 000) where
+///    - FP32 bits[25:31] = 000 0000
+///
+/// 2) F4E2M1 bits[1:2] = FP32 bits[23:24]
+///    Exception is 0.5 where
+///    - F4E2M1 bits[1:2] = 01, FP32 bits[23:24] = 00
+///
+/// 3) F4E2M1 bits[4] = FP32 bits[32] (sign bits are equal)
+///
+/// 4) FP32 bits[1:22] = 0
struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
  using OpRewritePattern::OpRewritePattern;
-  F4E2M1ExtFOpConverter(MLIRContext *context, PatternBenefit benefit = 1)
-      : OpRewritePattern<arith::ExtFOp>(context, benefit) {}
  LogicalResult matchAndRewrite(arith::ExtFOp op,
                                PatternRewriter &rewriter) const final {
-    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    Value operand = op.getOperand();
-    Type operandTy = operand.getType();
-    Type resultTy = op.getType();
-    Type operandETy = getElementTypeOrSelf(operandTy);
-    Type resultETy = getElementTypeOrSelf(resultTy);
-
-    if (!isa<Float4E2M1FNType>(operandETy)) {
-      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
-    }
-
-    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
-    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
-    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
-
-    Value bitcast = b.create<arith::BitcastOp>(i4Ty, operand);
-
-    Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
-    Value c0x00000014 = createConst(op->getLoc(), i32Ty, 22, rewriter);
-    Value c0x00000015 = createConst(op->getLoc(), i32Ty, 23, rewriter);
-    Value c0x0000001c = createConst(op->getLoc(), i32Ty, 28, rewriter);
-    Value cZero =
-        createFloatConst(op->getLoc(), f32Ty, APFloat(0.0f), rewriter);
-    Value cHalf =
-        createFloatConst(op->getLoc(), f32Ty, APFloat(0.5f), rewriter);
-
-    Value mantissaBitmask = c0x1;
-    Value exponentBitmask = createConst(op.getLoc(), i4Ty, 0x6, rewriter);
-    Value signBitmask = createConst(op.getLoc(), i4Ty, 0x8, rewriter);
-
-    Value f4SignBit = b.create<arith::AndIOp>(bitcast, signBitmask);
-    Value f32Bits = b.create<arith::ExtUIOp>(i32Ty, f4SignBit);
-    f32Bits = b.create<arith::ShLIOp>(f32Bits, c0x0000001c);
-
-    Value biasAdjustment = createConst(op.getLoc(), i32Ty, 126, rewriter);
-    Value f4ExpBits = b.create<arith::AndIOp>(bitcast, exponentBitmask);
-    f4ExpBits = b.create<arith::ShRUIOp>(f4ExpBits, c0x1);
-    Value f32ExpBits = b.create<arith::ExtUIOp>(i32Ty, f4ExpBits);
-    f32ExpBits = b.create<arith::AddIOp>(f32ExpBits, biasAdjustment);
-    Value f32Exp = b.create<arith::ShLIOp>(f32ExpBits, c0x00000015);
-    f32Bits = b.create<arith::AddIOp>(f32Bits, f32Exp);
-
-    Value f4ManBit = b.create<arith::AndIOp>(bitcast, mantissaBitmask);
-    Value f32ManBit = b.create<arith::ExtUIOp>(i32Ty, f4ManBit);
-    f32ManBit = b.create<arith::ShLIOp>(f32ManBit, c0x00000014);
-    f32Bits = b.create<arith::AddIOp>(f32Bits, f32ManBit);
-
-    // Special consideration for subnormal exponent (exp == 00).
-    Value isSubnormal = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
-                                                f32ExpBits, biasAdjustment);
-    Value isManSet =
-        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f4ManBit, c0x1);
-    Value subnormalVal = b.create<arith::SelectOp>(isManSet, cHalf, cZero);
-
-    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
-    result = b.create<arith::SelectOp>(isSubnormal, subnormalVal, result);
-    if (!isa<Float32Type>(resultETy)) {
-      result = b.create<arith::TruncFOp>(resultETy, operand);
-    }
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-
-struct ScalarF4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
-  using OpRewritePattern::OpRewritePattern;
-  ScalarF4E2M1ExtFOpConverter(MLIRContext *context, PatternBenefit benefit = 2)
-      : OpRewritePattern<arith::ExtFOp>(context, benefit) {}
-  LogicalResult matchAndRewrite(arith::ExtFOp op,
-                                PatternRewriter &rewriter) const final {
-    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Location loc = op.getLoc();
+    ImplicitLocOpBuilder b(loc, rewriter);
    Value operand = op.getOperand();
    Type operandTy = operand.getType();
    Type resultTy = op.getType();
    Type operandETy = getElementTypeOrSelf(operandTy);
    Type resultETy = getElementTypeOrSelf(resultTy);
-    if (isa<ShapedType>(operandTy))
-      return failure();
-
    if (!isa<Float4E2M1FNType>(operandETy))
      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
-    SmallVector<int> values = {
-        0x00000000, // 0.0
-        0x3f000000, // 0.5
-        0x3f800000, // 1.0
-        0x3fc00000, // 1.5
-        0x40000000, // 2.0
-        0x40400000, // 3.0
-        0x40800000, // 4.0
-        0x40c00000 // 6.0
-    };
-    // auto type = RankedTensorType::get({8}, b.getI32Type());
-    VectorType type = VectorType::get({8}, b.getI32Type());
-    SmallVector<Attribute> lookupTableAttr = llvm::map_to_vector(
-        values, [&](int v) -> Attribute { return b.getI32IntegerAttr(v); });
-    Value lookupTable = b.create<arith::ConstantOp>(
-        DenseIntElementsAttr::get(type, lookupTableAttr));
-
    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
-    Type i64Ty = cloneToShapedType(operandTy, b.getI64Type());
-
    Value i4Bits = b.create<arith::BitcastOp>(i4Ty, operand);
-    Value expManBitmask = createConst(op.getLoc(), i4Ty, 0x7, rewriter);
-    Value indexI4 = b.create<arith::AndIOp>(i4Bits, expManBitmask);
-    Value indexI64 = b.create<arith::ExtUIOp>(i64Ty, indexI4);
-    Value index = b.create<arith::IndexCastOp>(b.getIndexType(), indexI64);
-
-    Value c0x0000001c = createConst(op->getLoc(), i32Ty, 28, rewriter);
-    Value signBitmask = createConst(op.getLoc(), i4Ty, 0x8, rewriter);
-    Value signBitI4 = b.create<arith::AndIOp>(i4Bits, signBitmask);
-    Value signBitI32 = b.create<arith::ExtUIOp>(i32Ty, signBitI4);
-    signBitI32 = b.create<arith::ShLIOp>(signBitI32, c0x0000001c);
-
-    Value unsignedBits = b.create<vector::ExtractOp>(lookupTable, index);
-    Value f32Bits = b.create<arith::OrIOp>(signBitI32, unsignedBits);
-    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+    Value c0x0 = createConst(loc, i4Ty, 0x0, rewriter);
+    Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter);
+    Value c0x2 = createConst(loc, i4Ty, 0x2, rewriter);
+    Value c0x4 = createConst(loc, i4Ty, 0x4, rewriter);
+    Value c0x7 = createConst(loc, i4Ty, 0x7, rewriter);
+
+    // Rules 1) and 2) above only concern the magnitude bits; mask off the
+    // sign bit so negative inputs pick the same exponent/mantissa.
+    Value i4BitsNoSign = b.create<arith::AndIOp>(i4Bits, c0x7);
+
+    // Set last Exponent bit and Mantissa.
+    Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter);
+    Value bits1To24 = b.create<arith::ShLIOp>(i4BitsNoSign, c0x2);
+    Value isHalf =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, i4BitsNoSign, c0x1);
+    bits1To24 = b.create<arith::SelectOp>(isHalf, c0x0, bits1To24);
+    bits1To24 = b.create<arith::ExtUIOp>(i32Ty, bits1To24);
+    bits1To24 = b.create<arith::ShLIOp>(bits1To24, c0x00000014);
+
+    // Set first 7 bits of Exponent.
+    Value zeroExpBits = createConst(loc, i32Ty, 0x00000000, rewriter);
+    Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter);
+    Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter);
+    Value useLargerExp =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, i4BitsNoSign, c0x4);
+    Value bits25To31 =
+        b.create<arith::SelectOp>(useLargerExp, highExpBits, lowExpBits);
+    Value zeroExp =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, i4BitsNoSign, c0x0);
+    bits25To31 = b.create<arith::SelectOp>(zeroExp, zeroExpBits, bits25To31);
+
+    // Set sign.
+    Value c0x80000000 = createConst(loc, i32Ty, 0x80000000, rewriter);
+    Value c0x8 = createConst(loc, i4Ty, 0x8, rewriter);
+    Value negative =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x8);
+    Value bit32 = b.create<arith::SelectOp>(negative, c0x80000000, zeroExpBits);
+
+    // Add segments together.
+    Value bits1To31 = b.create<arith::AddIOp>(bits1To24, bits25To31);
+    Value bits1To32 = b.create<arith::AddIOp>(bits1To31, bit32);
+    Value result = b.create<arith::BitcastOp>(f32Ty, bits1To32);
-    if (!isa<Float32Type>(resultETy))
-      result = b.create<arith::TruncFOp>(resultETy, operand);
+    // The bit pattern above is an f32; convert it (not the F4E2M1 operand) to
+    // the requested result type, which may be narrower or wider than f32.
+    unsigned resultWidth = resultETy.getIntOrFloatBitWidth();
+    if (resultWidth < 32)
+      result = b.create<arith::TruncFOp>(resultTy, result);
+    else if (resultWidth > 32)
+      result = b.create<arith::ExtFOp>(resultTy, result);
@@ -522,15 +483,15 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
/// Table of representable values in F4E2M1:
///
/// Note: x is sign bit
-/// | Binary | Value ( + / - )
-/// | x000 | 0.0
-/// | x001 | 0.5
-/// | x010 | 1.0
-/// | x011 | 1.5
-/// | x100 | 2.0
-/// | x101 | 3.0
-/// | x110 | 4.0
-/// | x111 | 6.0
+/// | Binary | F4E2M1 | f32 bits 31..22 (positive sign)
+/// | x000   | 0.0    | 0000 0000 00
+/// | x001   | 0.5    | 0011 1111 00
+/// | x010   | 1.0    | 0011 1111 10
+/// | x011   | 1.5    | 0011 1111 11
+/// | x100   | 2.0    | 0100 0000 00
+/// | x101   | 3.0    | 0100 0000 01
+/// | x110   | 4.0    | 0100 0000 10
+/// | x111   | 6.0    | 0100 0000 11
///
/// Conversion procedure:
/// Step 1: Clamp to representable bounds.
@@ -543,7 +504,8 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::TruncFOp op,
PatternRewriter &rewriter) const final {
- ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ Location loc = op.getLoc();
+ ImplicitLocOpBuilder b(loc, rewriter);
Value operand = op.getOperand();
Type operandTy = operand.getType();
Type resultTy = op.getType();
@@ -560,34 +522,30 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
- Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
- Value c0x3 = createConst(op->getLoc(), i4Ty, 3, rewriter);
- Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
- Value c0x00 = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
- Value c0xff = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
- Value c0x00000000 = createConst(op.getLoc(), i32Ty, 0, rewriter);
+ Value c0x1 = createConst(loc, i4Ty, 1, rewriter);
+ Value c0x3 = createConst(loc, i4Ty, 3, rewriter);
+ Value c0x00000016 = createConst(loc, i32Ty, 22, rewriter);
+ Value c0x00 = createConst(loc, i8Ty, 0x00, rewriter);
+ Value c0xff = createConst(loc, i8Ty, 0xff, rewriter);
+ Value zeroExpBits = createConst(loc, i32Ty, 0, rewriter);
// Step 0: Clamp to bounds.
- Value cHigherBound =
- createFloatConst(op->getLoc(), f32Ty, APFloat(6.0f), rewriter);
- Value cLowerBound =
- createFloatConst(op->getLoc(), f32Ty, APFloat(-6.0f), rewriter);
+ Value cHigherBound = createFloatConst(loc, f32Ty, APFloat(6.0f), rewriter);
+ Value cLowerBound = createFloatConst(loc, f32Ty, APFloat(-6.0f), rewriter);
Value operandClamped = b.create<arith::MinNumFOp>(cHigherBound, operand);
operandClamped = b.create<arith::MaxNumFOp>(cLowerBound, operandClamped);
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
// Step 1: Set sign bit.
- Value cF32ExpManWidth =
- createConst(op->getLoc(), i32Ty, 31, rewriter); // 23
+ Value cF32ExpManWidth = createConst(loc, i32Ty, 31, rewriter); // 23
Value f32Sign = b.create<arith::ShRUIOp>(f32Bits, cF32ExpManWidth);
Value f4Sign = b.create<arith::TruncIOp>(i4Ty, f32Sign);
Value f4Bits = b.create<arith::ShLIOp>(f4Sign, c0x3);
// Step 2: Convert exponent by adjusting bias.
- Value biasAdjustment = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
- Value cF4MantissaWidth = c0x1; // 1
- Value cF32MantissaWidth =
- createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
+ Value biasAdjustment = createConst(loc, i32Ty, 0x7e, rewriter);
+ Value cF4MantissaWidth = c0x1; // 1
+ Value cF32MantissaWidth = createConst(loc, i32Ty, 23, rewriter); // 23
Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
Value biasAdjustedSignExp =
b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
@@ -596,16 +554,14 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
f4Bits = b.create<arith::AddIOp>(f4Bits, f4Exp);
// Step 3: Set mantissa to first bit.
- Value cF32FirstBitMask =
- createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
+ Value cF32FirstBitMask = createConst(loc, i32Ty, 0x400000, rewriter);
Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
f4Bits = b.create<arith::AddIOp>(f4Bits, f4Man);
// Step 4: Special consideration for conversion to 0.5.
- Value cF32MantissaMask =
- createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
+ Value cF32MantissaMask = createConst(loc, i32Ty, 0x7fffff, rewriter);
Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
Value isSubnormal =
b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
@@ -613,22 +569,20 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
Value isNonZeroMan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt,
- man23Bits, c0x00000000);
+ man23Bits, zeroExpBits);
Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
Value isZeroExp =
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
- Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
- Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
+ Value subnormalF4Bits = createConst(loc, i4Ty, 0xf, rewriter);
+ Value halfF4Bits = createConst(loc, i4Ty, 0x0, rewriter);
Value subResult =
b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
// Step 5: Round up if necessary.
- Value cF32Last22BitMask =
- createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
- Value cRound =
- createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
+ Value cF32Last22BitMask = createConst(loc, i32Ty, 0x3fffff, rewriter);
+ Value cRound = createConst(loc, i32Ty, 0x200000, rewriter); // 010 0000...
Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
Value shouldRound =
b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
@@ -848,8 +802,8 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
}
void mlir::arith::populateExpandF4E2M1Patterns(RewritePatternSet &patterns) {
- patterns.add<F4E2M1ExtFOpConverter, ScalarF4E2M1ExtFOpConverter,
- F4E2M1TruncFOpConverter>(patterns.getContext());
+ patterns.add<F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
+ patterns.getContext());
}
void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
More information about the Mlir-commits
mailing list