[Mlir-commits] [mlir] 379a609 - [mlir][arith][transforms] Adds f4E2M1FN support to truncf and extf (#144157)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jun 20 09:27:38 PDT 2025
Author: Muzammil
Date: 2025-06-20T11:27:35-05:00
New Revision: 379a609dadc1733c2b62d2bf3bab6e8032236836
URL: https://github.com/llvm/llvm-project/commit/379a609dadc1733c2b62d2bf3bab6e8032236836
DIFF: https://github.com/llvm/llvm-project/commit/379a609dadc1733c2b62d2bf3bab6e8032236836.diff
LOG: [mlir][arith][transforms] Adds f4E2M1FN support to truncf and extf (#144157)
See work detail: https://github.com/iree-org/iree/issues/20920
Add support for f4E2M1FN in `arith.truncf` and `arith.extf` ops through a software emulation
---------
Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
Added:
mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
Modified:
mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
mlir/test/Dialect/Arith/expand-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index e0a4567d6f406..b03cf2db78041 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -59,6 +59,9 @@ void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
/// Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
+/// Add patterns to expand Arith f4e2m1 patterns to lower level bitcasts/shifts:
+/// `arith.extf` from f4E2M1FN and `arith.truncf` to f4E2M1FN are rewritten
+/// into integer bitcasts, shifts, compares and selects.
+void populateExpandF4E2M1Patterns(RewritePatternSet &patterns);
+
/// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index e14b2aeee1c69..c7370b83fdb6c 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -19,6 +19,8 @@ def ArithExpandOpsPass : Pass<"arith-expand"> {
"Enable the BF16 expansion patterns">,
Option<"includeF8E8M0", "include-f8e8m0", "bool", /*default=*/"false",
"Enable the F8E8M0 expansion patterns">,
+ Option<"includeF4E2M1", "include-f4e2m1", "bool", /*default=*/"false",
+ "Enable the F4E2M1 expansion patterns">,
];
}
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 534aff9562b7a..dfa01844737c6 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -11,8 +11,11 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SmallVectorExtras.h"
+#include <cstdint>
namespace mlir {
namespace arith {
@@ -34,6 +37,18 @@ static Value createConst(Location loc, Type type, int value,
return rewriter.create<arith::ConstantOp>(loc, attr);
}
+/// Materialize an `arith.constant` holding the float `value`.
+///
+/// When `type` is a shaped type the constant is a splat of `value` with that
+/// shape; otherwise it is a scalar constant of `type`.
+static Value createFloatConst(Location loc, Type type, APFloat value,
+                              PatternRewriter &rewriter) {
+  auto scalarAttr =
+      rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
+  auto shapedTy = dyn_cast<ShapedType>(type);
+  if (!shapedTy)
+    return rewriter.create<arith::ConstantOp>(loc, scalarAttr);
+  return rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(shapedTy, scalarAttr));
+}
+
/// Creates shapedType using shape from cloneFrom and base type from cloneTo
static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
@@ -322,6 +337,100 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
}
};
+/// Expands `arith.extf` from F4E2M1FN by assembling the f32 bit pattern of
+/// the value directly from the 4 input bits.
+///
+/// Let the input be s|e1|e0|m (sign, two exponent bits, one mantissa bit).
+/// The f32 result, shown as bits [31:22] (sign | 8 exponent bits | mantissa
+/// MSB), is:
+///
+/// | Binary | F4E2M1 | f32 bits [31:22]
+/// | x000   | 0.0    | x 0000 0000 0
+/// | x001   | 0.5    | x 0111 1110 0
+/// | x010   | 1.0    | x 0111 1111 0
+/// | x011   | 1.5    | x 0111 1111 1
+/// | x100   | 2.0    | x 1000 0000 0
+/// | x101   | 3.0    | x 1000 0000 1
+/// | x110   | 4.0    | x 1000 0001 0
+/// | x111   | 6.0    | x 1000 0001 1
+///
+/// 1) f32 bits [30:24] (top 7 exponent bits) take only three values:
+///    0b0000000 for zero, 0b0111111 when e1 == 0, 0b1000000 when e1 == 1.
+/// 2) f32 bit 23 = e0 and f32 bit 22 = m, except for 0.5 (magnitude 0b001,
+///    the only subnormal), where both are 0.
+/// 3) f32 bit 31 = s. +x and -x differ only in this bit, so 1) and 2) are
+///    evaluated on the magnitude (input bits with the sign cleared).
+/// 4) f32 bits [21:0] are always 0.
+struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    ImplicitLocOpBuilder b(loc, rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!isa<Float4E2M1FNType>(operandETy))
+      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
+
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Value i4Bits = b.create<arith::BitcastOp>(i4Ty, operand);
+
+    Value c0x0 = createConst(loc, i4Ty, 0x0, rewriter);
+    Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter);
+    Value c0x2 = createConst(loc, i4Ty, 0x2, rewriter);
+    Value c0x4 = createConst(loc, i4Ty, 0x4, rewriter);
+    Value c0x7 = createConst(loc, i4Ty, 0x7, rewriter);
+
+    // Clear the sign bit. Exponent and mantissa depend only on the magnitude;
+    // comparing the raw bits would treat every negative input as >= 0x4 and
+    // would miss -0.0 and -0.5.
+    Value magBits = b.create<arith::AndIOp>(i4Bits, c0x7);
+
+    // Set f32 bit 23 (exponent LSB) and bit 22 (mantissa MSB). Shifting the
+    // magnitude left by 2 within i4 drops e1 and leaves e0|m in bits [3:2].
+    Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter);
+    Value bits22To23 = b.create<arith::ShLIOp>(magBits, c0x2);
+    Value isHalf =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, magBits, c0x1);
+    bits22To23 = b.create<arith::SelectOp>(isHalf, c0x0, bits22To23);
+    bits22To23 = b.create<arith::ExtUIOp>(i32Ty, bits22To23);
+    // Move bits [3:2] up to [23:22].
+    bits22To23 = b.create<arith::ShLIOp>(bits22To23, c0x00000014);
+
+    // Set f32 bits [30:24] (top 7 exponent bits).
+    Value zeroExpBits = createConst(loc, i32Ty, 0x00000000, rewriter);
+    Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter);
+    Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter);
+    Value useLargerExp =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, magBits, c0x4);
+    Value bits24To30 =
+        b.create<arith::SelectOp>(useLargerExp, highExpBits, lowExpBits);
+    Value zeroExp =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, magBits, c0x0);
+    bits24To30 = b.create<arith::SelectOp>(zeroExp, zeroExpBits, bits24To30);
+
+    // Set f32 bit 31 (sign), taken from the raw input bits.
+    Value c0x80000000 = createConst(loc, i32Ty, 0x80000000, rewriter);
+    Value c0x8 = createConst(loc, i4Ty, 0x8, rewriter);
+    Value negative =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x8);
+    Value bit31 = b.create<arith::SelectOp>(negative, c0x80000000, zeroExpBits);
+
+    // The three fields occupy disjoint bits, so OR-ing them assembles the
+    // f32 encoding.
+    Value bits22To30 = b.create<arith::OrIOp>(bits22To23, bits24To30);
+    Value bits22To31 = b.create<arith::OrIOp>(bits22To30, bit31);
+    Value result = b.create<arith::BitcastOp>(f32Ty, bits22To31);
+
+    // Reach the requested result type from f32. A narrower type needs truncf;
+    // a wider one (e.g. f64) needs extf, because truncf to a wider type is
+    // invalid IR.
+    unsigned resultWidth = resultETy.getIntOrFloatBitWidth();
+    if (resultWidth < 32)
+      result = b.create<arith::TruncFOp>(resultTy, result);
+    else if (resultWidth > 32)
+      result = b.create<arith::ExtFOp>(resultTy, result);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::ExtFOp op,
@@ -366,6 +475,130 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
}
};
+/// Conversion from F32 to F4E2M1 according to the OCP Spec:
+/// www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+///
+/// The spec requires Round to Nearest, Ties to Even. An input is rounded to
+/// the nearest representable value. An input exactly halfway between two
+/// representable values goes to the one whose mantissa LSB is 0.
+///
+/// Table of representable values in F4E2M1 (x is the sign bit):
+///
+/// | Binary | F4E2M1 |
+/// | x000   | 0.0    |
+/// | x001   | 0.5    | (subnormal)
+/// | x010   | 1.0    |
+/// | x011   | 1.5    |
+/// | x100   | 2.0    |
+/// | x101   | 3.0    |
+/// | x110   | 4.0    |
+/// | x111   | 6.0    |
+///
+/// Conversion procedure, where E is the biased f32 exponent and M the 23-bit
+/// f32 mantissa of the clamped input:
+/// Step 0: Clamp to [-6.0, 6.0]. minnumf/maxnumf return the non-NaN operand,
+///         so NaN inputs become +6.0.
+/// Step 1: Move the sign bit to F4 bit 3. It is applied unconditionally, so
+///         negative inputs that round to zero produce -0.0.
+/// Step 2: E >= 127 (|x| >= 1.0). The F4 exponent is E - 126 and the F4
+///         mantissa is M[22]. Round up when M[21:0] is above half, or when it
+///         is exactly half and M[22] == 1 (ties to even).
+/// Step 3: E == 126 (0.5 <= |x| < 1.0). Result is 1.0 when M[22] is set
+///         (|x| >= 0.75; the 0.75 tie goes to the even code), else 0.5.
+/// Step 4: E == 125 (0.25 <= |x| < 0.5). Result is 0.5 when M != 0, else 0.0
+///         (0.25 is a tie that goes to the even code 0.0).
+/// Step 5: E < 125 (|x| < 0.25) gives 0.0. Select the magnitude by exponent
+///         range and combine it with the sign.
+///
+/// Operands narrower than f32 are first extended to f32, which is exact.
+/// Operands wider than f32 are rejected: truncating them to f32 first would
+/// round twice and can differ from a single correct rounding.
+struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::TruncFOp op,
+                                PatternRewriter &rewriter) const final {
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    // Make every match decision before creating any IR: a pattern must not
+    // modify the IR and then report failure.
+    if (!isa<Float4E2M1FNType>(resultETy))
+      return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
+    if (operandETy.getIntOrFloatBitWidth() > 32)
+      return rewriter.notifyMatchFailure(
+          op, "operands wider than f32 would be double-rounded");
+
+    Location loc = op.getLoc();
+    ImplicitLocOpBuilder b(loc, rewriter);
+    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+    // Extending a narrower float to f32 is exact, so rounding happens once.
+    if (!isa<Float32Type>(operandETy))
+      operand = b.create<arith::ExtFOp>(f32Ty, operand);
+
+    Value i4Zero = createConst(loc, i4Ty, 0x0, rewriter);
+    Value i4One = createConst(loc, i4Ty, 0x1, rewriter);
+    Value i4Two = createConst(loc, i4Ty, 0x2, rewriter);
+    Value i4Three = createConst(loc, i4Ty, 0x3, rewriter);
+    Value i32Zero = createConst(loc, i32Ty, 0, rewriter);
+    Value i32One = createConst(loc, i32Ty, 1, rewriter);
+
+    // Step 0: Clamp to the representable range.
+    Value cHigherBound = createFloatConst(loc, f32Ty, APFloat(6.0f), rewriter);
+    Value cLowerBound = createFloatConst(loc, f32Ty, APFloat(-6.0f), rewriter);
+    Value operandClamped = b.create<arith::MinNumFOp>(cHigherBound, operand);
+    operandClamped = b.create<arith::MaxNumFOp>(cLowerBound, operandClamped);
+    Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
+
+    // Step 1: Sign bit (f32 bit 31) becomes F4 bit 3.
+    Value cSignShift = createConst(loc, i32Ty, 31, rewriter);
+    Value f32Sign = b.create<arith::ShRUIOp>(f32Bits, cSignShift);
+    Value f4Sign = b.create<arith::TruncIOp>(i4Ty, f32Sign);
+    f4Sign = b.create<arith::ShLIOp>(f4Sign, i4Three);
+
+    // Split the magnitude into the biased exponent E and the mantissa M.
+    Value cF32MantissaWidth = createConst(loc, i32Ty, 23, rewriter);
+    Value cF32AbsMask = createConst(loc, i32Ty, 0x7fffffff, rewriter);
+    Value cF32MantissaMask = createConst(loc, i32Ty, 0x7fffff, rewriter);
+    Value absBits = b.create<arith::AndIOp>(f32Bits, cF32AbsMask);
+    Value biasedExp = b.create<arith::ShRUIOp>(absBits, cF32MantissaWidth);
+    Value mantissa = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
+    // M[22] is the only mantissa bit F4E2M1 keeps.
+    Value cManMsbShift = createConst(loc, i32Ty, 22, rewriter);
+    Value manMsb = b.create<arith::ShRUIOp>(mantissa, cManMsbShift);
+    Value manMsbSet =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::ne, manMsb, i32Zero);
+
+    // Step 2: Normal range. After clamping E <= 129, so the F4 exponent
+    // E - 126 is in [1, 3] and ((E - 126) << 1) | M[22] fits in 3 bits.
+    Value cBiasAdjust = createConst(loc, i32Ty, 126, rewriter);
+    Value f4Exp = b.create<arith::SubIOp>(biasedExp, cBiasAdjust);
+    f4Exp = b.create<arith::ShLIOp>(f4Exp, i32One);
+    Value normalCode = b.create<arith::OrIOp>(f4Exp, manMsb);
+    normalCode = b.create<arith::TruncIOp>(i4Ty, normalCode);
+    Value cRoundMask = createConst(loc, i32Ty, 0x3fffff, rewriter);
+    Value cHalf = createConst(loc, i32Ty, 0x200000, rewriter);
+    Value roundBits = b.create<arith::AndIOp>(mantissa, cRoundMask);
+    Value aboveHalf =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, roundBits, cHalf);
+    Value exactlyHalf =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, roundBits, cHalf);
+    // A tie rounds up only when it makes the (currently odd) LSB even.
+    Value tieRoundsUp = b.create<arith::AndIOp>(exactlyHalf, manMsbSet);
+    Value roundUp = b.create<arith::OrIOp>(aboveHalf, tieRoundsUp);
+    // Cannot overflow: code 0b111 (6.0) only arises from the clamped value
+    // 6.0 itself, whose remaining mantissa bits are all zero.
+    Value roundedCode = b.create<arith::AddIOp>(normalCode, i4One);
+    normalCode = b.create<arith::SelectOp>(roundUp, roundedCode, normalCode);
+
+    // Step 3: E == 126, choosing between 0.5 (0b001) and 1.0 (0b010).
+    Value halfRangeCode = b.create<arith::SelectOp>(manMsbSet, i4Two, i4One);
+
+    // Step 4: E == 125, choosing between 0.0 (0b000) and 0.5 (0b001).
+    Value manNonZero =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::ne, mantissa, i32Zero);
+    Value quarterRangeCode =
+        b.create<arith::SelectOp>(manNonZero, i4One, i4Zero);
+
+    // Step 5: Select the magnitude code by exponent range; E < 125 gives 0.
+    Value cExpNormal = createConst(loc, i32Ty, 127, rewriter);
+    Value cExpQuarter = createConst(loc, i32Ty, 125, rewriter);
+    Value isNormal = b.create<arith::CmpIOp>(arith::CmpIPredicate::uge,
+                                             biasedExp, cExpNormal);
+    Value isHalfRange = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
+                                                biasedExp, cBiasAdjust);
+    Value isQuarterRange = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
+                                                   biasedExp, cExpQuarter);
+    Value magnitude =
+        b.create<arith::SelectOp>(isQuarterRange, quarterRangeCode, i4Zero);
+    magnitude = b.create<arith::SelectOp>(isHalfRange, halfRangeCode, magnitude);
+    magnitude = b.create<arith::SelectOp>(isNormal, normalCode, magnitude);
+
+    Value f4Bits = b.create<arith::OrIOp>(f4Sign, magnitude);
+    Value result = b.create<arith::BitcastOp>(resultTy, f4Bits);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
/*
TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
Since All kinds of Infs and NaNs are mapped to same exponent bits in F32 type,
@@ -498,6 +731,8 @@ struct ArithExpandOpsPass
arith::populateArithExpandOpsPatterns(patterns);
target.addLegalDialect<arith::ArithDialect>();
+ target.addLegalDialect<vector::VectorDialect>();
+
// clang-format off
target.addIllegalOp<
arith::CeilDivSIOp,
@@ -515,22 +750,24 @@ struct ArithExpandOpsPass
arith::ScalingTruncFOp
>();
- if (includeBf16) {
+ if (includeBf16)
arith::populateExpandBFloat16Patterns(patterns);
- }
- if (includeF8E8M0) {
+ if (includeF8E8M0)
arith::populateExpandF8E8M0Patterns(patterns);
- }
+ if (includeF4E2M1)
+ arith::populateExpandF4E2M1Patterns(patterns);
target.addDynamicallyLegalOp<arith::ExtFOp>(
[=](arith::ExtFOp op) {
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
Type outETy = getElementTypeOrSelf(op.getType());
bool legalTypes = true;
- if (includeBf16)
+ if (includeBf16)
legalTypes &= !(inETy.isBF16() && outETy.isF32());
if (includeF8E8M0)
legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
+ if (includeF4E2M1)
+ legalTypes &= !llvm::isa<Float4E2M1FNType>(inETy);
return legalTypes;
});
@@ -539,10 +776,12 @@ struct ArithExpandOpsPass
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
Type outETy = getElementTypeOrSelf(op.getType());
bool legalTypes = true;
- if (includeBf16)
+ if (includeBf16)
legalTypes &= !(inETy.isF32() && outETy.isBF16());
- if (includeF8E8M0)
+ if (includeF8E8M0)
legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
+ if (includeF4E2M1)
+ legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
return legalTypes;
});
@@ -567,6 +806,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
patterns.getContext());
}
+void mlir::arith::populateExpandF4E2M1Patterns(RewritePatternSet &patterns) {
+  // Register the extf expansion first, then the truncf expansion.
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<F4E2M1ExtFOpConverter>(ctx);
+  patterns.add<F4E2M1TruncFOpConverter>(ctx);
+}
+
void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
patterns.getContext());
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index db1349feaff3a..8f9b0feba442a 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -verify-diagnostics -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true include-f4e2m1=true" -verify-diagnostics -split-input-file | FileCheck %s
// RUN: mlir-opt %s -arith-expand -split-input-file -verify-diagnostics | FileCheck %s --check-prefix=SCHECK
// Test ceil divide with signed integer
@@ -593,3 +593,43 @@ func.func @minui(%a: i32, %b: i32) -> i32 {
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: return %[[RESULT]] : i32
+
+// -----
+
+func.func @truncf_f32_to_f4E2M1FN(%arg0 : f32) -> f4E2M1FN {
+  %0 = arith.truncf %arg0 : f32 to f4E2M1FN
+  return %0 : f4E2M1FN
+}
+
+// The expansion works on the f32 bit pattern and rebuilds the f4 value from
+// an i4 result; check both ends and that the rebuilt value is returned.
+// CHECK-LABEL: @truncf_f32_to_f4E2M1FN
+// CHECK-NOT: arith.truncf
+// CHECK: arith.bitcast %{{.*}} : f32 to i32
+// CHECK: %[[RES:.*]] = arith.bitcast %{{.*}} : i4 to f4E2M1FN
+// CHECK: return %[[RES]] : f4E2M1FN
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @truncf_vector_f32_to_f4E2M1FN(%arg0 : vector<4xf32>) -> vector<4xf4E2M1FN> {
+  %0 = arith.truncf %arg0 : vector<4xf32> to vector<4xf4E2M1FN>
+  return %0 : vector<4xf4E2M1FN>
+}
+
+// Same expansion as the scalar case, element-wise on vectors.
+// CHECK-LABEL: @truncf_vector_f32_to_f4E2M1FN
+// CHECK-NOT: arith.truncf
+// CHECK: arith.bitcast %{{.*}} : vector<4xf32> to vector<4xi32>
+// CHECK: %[[RES:.*]] = arith.bitcast %{{.*}} : vector<4xi4> to vector<4xf4E2M1FN>
+// CHECK: return %[[RES]] : vector<4xf4E2M1FN>
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @extf_f4E2M1FN_to_f32(%arg0 : f4E2M1FN) -> f32 {
+  %0 = arith.extf %arg0 : f4E2M1FN to f32
+  return %0 : f32
+}
+
+// The expansion reads the 4 input bits and assembles the f32 bit pattern.
+// CHECK-LABEL: @extf_f4E2M1FN_to_f32
+// CHECK-NOT: arith.extf
+// CHECK: arith.bitcast %{{.*}} : f4E2M1FN to i4
+// CHECK: %[[RES:.*]] = arith.bitcast %{{.*}} : i32 to f32
+// CHECK: return %[[RES]] : f32
+// CHECK-NOT: arith.extf
+
+// -----
+
+func.func @extf_vector_f4E2M1FN_to_f32(%arg0 : vector<4xf4E2M1FN>) -> vector<4xf32> {
+  %0 = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// Same expansion as the scalar case, element-wise on vectors.
+// CHECK-LABEL: @extf_vector_f4E2M1FN_to_f32
+// CHECK-NOT: arith.extf
+// CHECK: arith.bitcast %{{.*}} : vector<4xf4E2M1FN> to vector<4xi4>
+// CHECK: %[[RES:.*]] = arith.bitcast %{{.*}} : vector<4xi32> to vector<4xf32>
+// CHECK: return %[[RES]] : vector<4xf32>
+// CHECK-NOT: arith.extf
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
new file mode 100644
index 0000000000000..9c310d80d4c2d
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
@@ -0,0 +1,73 @@
+// Check various edge cases for truncf/extf ops involving f32 and f4e2m1 types.
+
+// RUN: mlir-opt %s --convert-func-to-llvm \
+// RUN: --arith-expand="include-f4e2m1=true" \
+// RUN: --convert-arith-to-llvm --convert-vector-to-llvm \
+// RUN: --reconcile-unrealized-casts | \
+// RUN: mlir-runner -e entry --entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s --match-full-lines
+
+// Widens %in to f32 and prints the resulting value.
+func.func @check_extf(%in : f4E2M1FN) -> () {
+  %widened = arith.extf %in : f4E2M1FN to f32
+  vector.print %widened : f32
+  return
+}
+
+// Truncates %in to f4E2M1FN and prints its 4-bit encoding as an unsigned
+// integer (bit 3 is the sign, bits [2:1] the exponent, bit 0 the mantissa).
+// See https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+// for details on F4E2M1 representation
+func.func @check_truncf(%in : f32) -> () {
+  %narrowed = arith.truncf %in : f32 to f4E2M1FN
+  %encoding = arith.bitcast %narrowed : f4E2M1FN to i4
+  %code = arith.extui %encoding : i4 to i64
+  vector.print %code : i64
+  return
+}
+
+func.func @entry() {
+  %zero = arith.constant 0.0 : f32
+  %half = arith.constant 0.5 : f32
+  %one = arith.constant 1.0 : f32
+  %max = arith.constant 6.0 : f32
+  %min = arith.constant -6.0 : f32
+  %lowerThanMin = arith.constant -1000000.0 : f32
+  %higherThanMax = arith.constant 1000000.0 : f32
+  %mustRound = arith.constant -3.14 : f32
+  // Bit pattern 0x07f80000 (about 3.7e-34): a tiny positive normal value, not
+  // a NaN (a quiet NaN would be 0x7fc00000). It must truncate to 0.0.
+  %tiny = arith.constant 0x7f80000 : f32
+  // Round-to-nearest-even ties and the subnormal boundary.
+  // 0.25 is a tie between 0.0 and 0.5 and goes to 0.0 (even).
+  %quarter = arith.constant 0.25 : f32
+  // Slightly above the 0.25 tie, so it goes to 0.5.
+  %aboveQuarter = arith.constant 0.3 : f32
+  // 0.75 is a tie between 0.5 and 1.0 and goes to 1.0 (even).
+  %threeQuarters = arith.constant 0.75 : f32
+  // 1.75 is a tie between 1.5 and 2.0 and goes to 2.0 (even).
+  %sevenQuarters = arith.constant 1.75 : f32
+  // 3.5 is a tie between 3.0 and 4.0 and goes to 4.0 (even).
+  %threeAndHalf = arith.constant 3.5 : f32
+
+  // CHECK: 0
+  func.call @check_truncf(%zero) : (f32) -> ()
+  // CHECK: 1
+  func.call @check_truncf(%half) : (f32) -> ()
+  // CHECK: 2
+  func.call @check_truncf(%one) : (f32) -> ()
+  // CHECK: 7
+  func.call @check_truncf(%max) : (f32) -> ()
+  // CHECK: 15
+  func.call @check_truncf(%min) : (f32) -> ()
+  // CHECK: 7
+  func.call @check_truncf(%higherThanMax) : (f32) -> ()
+  // CHECK: 15
+  func.call @check_truncf(%lowerThanMin) : (f32) -> ()
+  // CHECK: 13
+  func.call @check_truncf(%mustRound) : (f32) -> ()
+  // CHECK: 0
+  func.call @check_truncf(%tiny) : (f32) -> ()
+  // CHECK: 0
+  func.call @check_truncf(%quarter) : (f32) -> ()
+  // CHECK: 1
+  func.call @check_truncf(%aboveQuarter) : (f32) -> ()
+  // CHECK: 2
+  func.call @check_truncf(%threeQuarters) : (f32) -> ()
+  // CHECK: 4
+  func.call @check_truncf(%sevenQuarters) : (f32) -> ()
+  // CHECK: 6
+  func.call @check_truncf(%threeAndHalf) : (f32) -> ()
+
+  // CHECK: 0
+  %zeroF4 = arith.truncf %zero : f32 to f4E2M1FN
+  func.call @check_extf(%zeroF4) : (f4E2M1FN) -> ()
+  // CHECK: 0.5
+  %halfF4 = arith.truncf %half : f32 to f4E2M1FN
+  func.call @check_extf(%halfF4) : (f4E2M1FN) -> ()
+  // CHECK: 1
+  %oneF4 = arith.truncf %one : f32 to f4E2M1FN
+  func.call @check_extf(%oneF4) : (f4E2M1FN) -> ()
+  // CHECK: 2
+  %sevenQuartersF4 = arith.truncf %sevenQuarters : f32 to f4E2M1FN
+  func.call @check_extf(%sevenQuartersF4) : (f4E2M1FN) -> ()
+  // CHECK: 6
+  %higherThanMaxF4 = arith.truncf %higherThanMax : f32 to f4E2M1FN
+  func.call @check_extf(%higherThanMaxF4) : (f4E2M1FN) -> ()
+  // CHECK: -6
+  %lowerThanMinF4 = arith.truncf %lowerThanMin : f32 to f4E2M1FN
+  func.call @check_extf(%lowerThanMinF4) : (f4E2M1FN) -> ()
+  // CHECK: -3
+  %mustRoundF4 = arith.truncf %mustRound : f32 to f4E2M1FN
+  func.call @check_extf(%mustRoundF4) : (f4E2M1FN) -> ()
+  return
+}
More information about the Mlir-commits
mailing list