[Mlir-commits] [mlir] 379a609 - [mlir][arith][transforms] Adds f4E2M1FN support to truncf and extf (#144157)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 20 09:27:38 PDT 2025


Author: Muzammil
Date: 2025-06-20T11:27:35-05:00
New Revision: 379a609dadc1733c2b62d2bf3bab6e8032236836

URL: https://github.com/llvm/llvm-project/commit/379a609dadc1733c2b62d2bf3bab6e8032236836
DIFF: https://github.com/llvm/llvm-project/commit/379a609dadc1733c2b62d2bf3bab6e8032236836.diff

LOG: [mlir][arith][transforms] Adds f4E2M1FN support to truncf and extf (#144157)

See work detail: https://github.com/iree-org/iree/issues/20920

Add support for f4E2M1FN in `arith.truncf` and `arith.extf` ops through software emulation

---------

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>

Added: 
    mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir

Modified: 
    mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
    mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
    mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
    mlir/test/Dialect/Arith/expand-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index e0a4567d6f406..b03cf2db78041 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -59,6 +59,9 @@ void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
 /// Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
 void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
 
+/// Add patterns that expand `arith.extf` from f4E2M1FN and `arith.truncf` to
+/// f4E2M1FN into lower-level integer bit manipulation (bitcasts, shifts,
+/// compares and selects) on f32 bit patterns.
+void populateExpandF4E2M1Patterns(RewritePatternSet &patterns);
+
 /// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
 void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
 

diff  --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index e14b2aeee1c69..c7370b83fdb6c 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -19,6 +19,8 @@ def ArithExpandOpsPass : Pass<"arith-expand"> {
               "Enable the BF16 expansion patterns">,
        Option<"includeF8E8M0", "include-f8e8m0", "bool", /*default=*/"false",
               "Enable the F8E8M0 expansion patterns">,
+       Option<"includeF4E2M1", "include-f4e2m1", "bool", /*default=*/"false",
+              "Enable the F4E2M1 expansion patterns">,
   ];
 }
 

diff  --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 534aff9562b7a..dfa01844737c6 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -11,8 +11,11 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SmallVectorExtras.h"
+#include <cstdint>
 
 namespace mlir {
 namespace arith {
@@ -34,6 +37,18 @@ static Value createConst(Location loc, Type type, int value,
   return rewriter.create<arith::ConstantOp>(loc, attr);
 }
 
+/// Materializes `value` as an `arith.constant` of `type`. A shaped `type`
+/// produces a splat DenseElementsAttr; a scalar `type` a plain FloatAttr.
+static Value createFloatConst(Location loc, Type type, APFloat value,
+                              PatternRewriter &rewriter) {
+  FloatAttr scalarAttr =
+      rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
+  auto shaped = dyn_cast<ShapedType>(type);
+  if (!shaped)
+    return rewriter.create<arith::ConstantOp>(loc, scalarAttr);
+  return rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(shaped, scalarAttr));
+}
+
 /// Creates shapedType using shape from cloneFrom and base type from cloneTo
 static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
   if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
@@ -322,6 +337,100 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   }
 };
 
+/// Expands `arith.extf` from F4E2M1FN by assembling the F32 bit pattern
+/// directly from the 4 input bits. If the requested result type is not F32,
+/// the F32 value is then truncated (narrower result) or extended (wider
+/// result) to it.
+///
+/// F4E2M1FN layout is s:ee:m (1 sign, 2 exponent, 1 mantissa bit). The table
+/// shows F32 bits 31..22 (sign, 8 exponent bits, first mantissa bit); all
+/// lower F32 bits are 0 for every F4E2M1FN value.
+///
+/// | Binary | F4E2M1 | F32[31:22]
+/// | s000   | 0.0    | s000 0000 00
+/// | s001   | 0.5    | s011 1111 00
+/// | s010   | 1.0    | s011 1111 10
+/// | s011   | 1.5    | s011 1111 11
+/// | s100   | 2.0    | s100 0000 00
+/// | s101   | 3.0    | s100 0000 01
+/// | s110   | 4.0    | s100 0000 10
+/// | s111   | 6.0    | s100 0000 11
+///
+/// 1) F32 bits 30..24 take one of two values, chosen by the high F4 exponent
+///    bit: 0111111 (0x3f000000) for magnitudes < 2.0, 1000000 (0x40000000)
+///    for magnitudes >= 2.0. Exception: +-0.0, where they are all zero.
+/// 2) F32 bits 23..22 equal F4 bits 1..0.
+///    Exception: 0.5 (F4 subnormal), where F32 bits 23..22 are 00.
+/// 3) F32 bit 31 equals F4 bit 3 (the sign).
+/// 4) F32 bits 21..0 are 0.
+///
+/// Rules 1) and 2) depend only on the magnitude, so the sign bit is masked off
+/// before evaluating them.
+struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    ImplicitLocOpBuilder b(loc, rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!isa<Float4E2M1FNType>(operandETy))
+      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
+
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Value i4Bits = b.create<arith::BitcastOp>(i4Ty, operand);
+
+    Value c0x0 = createConst(loc, i4Ty, 0x0, rewriter);
+    Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter);
+    Value c0x2 = createConst(loc, i4Ty, 0x2, rewriter);
+    Value c0x4 = createConst(loc, i4Ty, 0x4, rewriter);
+    Value c0x7 = createConst(loc, i4Ty, 0x7, rewriter);
+
+    // Strip the sign bit. Comparing the raw bits would make every negative
+    // input compare `uge 4` and never match 0 or 0.5, mis-encoding
+    // -0.0/-0.5/-1.0/-1.5.
+    Value magBits = b.create<arith::AndIOp>(i4Bits, c0x7);
+
+    // F32 bits 23..22 (last exponent bit and first mantissa bit) come from F4
+    // bits 1..0. The i4 shift left by 2 discards F4 bits 3..2; the i32 shift
+    // left by 20 then lands the pair at bits 23..22.
+    Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter);
+    Value lowExpAndMan = b.create<arith::ShLIOp>(magBits, c0x2);
+    Value isHalf =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, magBits, c0x1);
+    lowExpAndMan = b.create<arith::SelectOp>(isHalf, c0x0, lowExpAndMan);
+    lowExpAndMan = b.create<arith::ExtUIOp>(i32Ty, lowExpAndMan);
+    lowExpAndMan = b.create<arith::ShLIOp>(lowExpAndMan, c0x00000014);
+
+    // F32 bits 30..24 (upper 7 exponent bits).
+    Value zeroExpBits = createConst(loc, i32Ty, 0x00000000, rewriter);
+    Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter);
+    Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter);
+    Value useLargerExp =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, magBits, c0x4);
+    Value highExp =
+        b.create<arith::SelectOp>(useLargerExp, highExpBits, lowExpBits);
+    Value zeroExp =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, magBits, c0x0);
+    highExp = b.create<arith::SelectOp>(zeroExp, zeroExpBits, highExp);
+
+    // F32 bit 31 (sign): the F4 sign bit is set iff the raw i4 value >= 8.
+    Value c0x80000000 = createConst(loc, i32Ty, 0x80000000, rewriter);
+    Value c0x8 = createConst(loc, i4Ty, 0x8, rewriter);
+    Value negative =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x8);
+    Value signBit =
+        b.create<arith::SelectOp>(negative, c0x80000000, zeroExpBits);
+
+    // The three fields occupy disjoint bit ranges, so addition combines them.
+    Value magnitude = b.create<arith::AddIOp>(lowExpAndMan, highExp);
+    Value f32Bits = b.create<arith::AddIOp>(magnitude, signBit);
+    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+
+    // Every F4E2M1FN value is exact in F32. Narrower result types (f16, bf16,
+    // ...) need truncf; wider ones (f64) need extf, because a truncf to a
+    // wider type is invalid.
+    if (!isa<Float32Type>(resultETy)) {
+      if (cast<FloatType>(resultETy).getWidth() > 32)
+        result = b.create<arith::ExtFOp>(resultTy, result);
+      else
+        result = b.create<arith::TruncFOp>(resultTy, result);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(arith::ExtFOp op,
@@ -366,6 +475,130 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
   }
 };
 
+/// Conversion from F32 to F4E2M1FN according to the OCP Microscaling Formats
+/// (MX) v1.0 spec:
+/// www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+///
+/// The spec requires Round to Nearest, Ties to Even: an input exactly halfway
+/// between two representable values rounds to the one whose mantissa bit is 0.
+/// F4E2M1FN has no Inf/NaN encodings, so the input is first clamped to
+/// [-6.0, 6.0]. +-Inf saturate to +-6.0, and NaN becomes +6.0 because
+/// `arith.minnumf`/`arith.maxnumf` return their non-NaN operand.
+///
+/// Table of representable values (F4E2M1FN layout s:ee:m; the F32 column shows
+/// bits 31..22, i.e. sign, 8 exponent bits and the first mantissa bit):
+///
+/// | Binary | F4E2M1 | F32[31:22]
+/// | s000   | 0.0    | s000 0000 00
+/// | s001   | 0.5    | s011 1111 00
+/// | s010   | 1.0    | s011 1111 10
+/// | s011   | 1.5    | s011 1111 11
+/// | s100   | 2.0    | s100 0000 00
+/// | s101   | 3.0    | s100 0000 01
+/// | s110   | 4.0    | s100 0000 10
+/// | s111   | 6.0    | s100 0000 11
+///
+/// Conversion procedure (operating on the magnitude; the sign is applied last
+/// so that the magnitude selects below cannot clobber it):
+///   Step 0: Clamp to representable bounds.
+///   Step 1: Extract the sign bit.
+///   Step 2: Convert exponent by adjusting bias (F32 bias 127, F4 bias 1).
+///   Step 3: Set mantissa to the first F32 mantissa bit.
+///   Step 4: Special handling for magnitudes below 1.0 (F4 exponent <= 0).
+///   Step 5: Round to nearest, ties to even.
+///   Step 6: Apply the sign bit.
+struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::TruncFOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    ImplicitLocOpBuilder b(loc, rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    // Match before creating any IR: a pattern must not modify the IR and then
+    // report failure.
+    if (!isa<Float4E2M1FNType>(resultETy))
+      return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
+
+    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+    Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+    // Normalize the input to F32. Wider inputs (f64) must be truncated, since
+    // an extf to a narrower type is invalid.
+    // NOTE(review): f64 -> f32 -> f4 rounds twice, which can differ from a
+    // single correctly-rounded conversion in rare near-tie cases.
+    if (!isa<Float32Type>(operandETy)) {
+      if (cast<FloatType>(operandETy).getWidth() > 32)
+        operand = b.create<arith::TruncFOp>(f32Ty, operand);
+      else
+        operand = b.create<arith::ExtFOp>(f32Ty, operand);
+    }
+
+    Value c0x1 = createConst(loc, i4Ty, 1, rewriter);
+    Value c0x3 = createConst(loc, i4Ty, 3, rewriter);
+    Value c0x00000016 = createConst(loc, i32Ty, 22, rewriter);
+    Value c0x00 = createConst(loc, i8Ty, 0x00, rewriter);
+    Value c0xff = createConst(loc, i8Ty, 0xff, rewriter);
+    Value zeroExpBits = createConst(loc, i32Ty, 0, rewriter);
+
+    // Step 0: Clamp to bounds [-6.0, 6.0].
+    Value cHigherBound = createFloatConst(loc, f32Ty, APFloat(6.0f), rewriter);
+    Value cLowerBound = createFloatConst(loc, f32Ty, APFloat(-6.0f), rewriter);
+    Value operandClamped = b.create<arith::MinNumFOp>(cHigherBound, operand);
+    operandClamped = b.create<arith::MaxNumFOp>(cLowerBound, operandClamped);
+    Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
+
+    // Step 1: Extract the sign bit (F32 bit 31) into F4 bit 3.
+    Value cF32SignShift = createConst(loc, i32Ty, 31, rewriter);
+    Value f32Sign = b.create<arith::ShRUIOp>(f32Bits, cF32SignShift);
+    Value f4Sign = b.create<arith::TruncIOp>(i4Ty, f32Sign);
+    f4Sign = b.create<arith::ShLIOp>(f4Sign, c0x3);
+
+    // Step 2: Convert exponent by adjusting bias (127 - 1 = 0x7e). Only the
+    // low bits survive the i4 truncation, which is all that the normal range
+    // (F4 exponent 1..3) needs; other exponents are overridden in Step 4.
+    Value biasAdjustment = createConst(loc, i32Ty, 0x7e, rewriter);
+    Value cF4MantissaWidth = c0x1;
+    Value cF32MantissaWidth = createConst(loc, i32Ty, 23, rewriter);
+    Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
+    Value biasAdjustedSignExp =
+        b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
+    Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
+    Value f4MagBits = b.create<arith::ShLIOp>(f4Exp, cF4MantissaWidth);
+
+    // Step 3: Set mantissa to the first F32 mantissa bit (bit 22).
+    Value cF32FirstBitMask = createConst(loc, i32Ty, 0x400000, rewriter);
+    Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
+    man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
+    Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
+    f4MagBits = b.create<arith::AddIOp>(f4MagBits, f4Man);
+
+    // Step 4: Magnitudes below 1.0 (adjusted exponent <= 0). The patterns set
+    // here are pre-rounding values; Step 5 always adds 1 for these inputs:
+    //   exponent == 0, i.e. [0.5, 1.0): keep the mantissa bit
+    //     -> 0.5 (below 0.75) or 1.0 (0.75 and above, ties to even).
+    //   exponent == -1 with nonzero mantissa, i.e. (0.25, 0.5): 0x0 -> 0.5.
+    //   anything smaller, including exactly 0.25: 0xf -> wraps to 0x0 (zero).
+    Value cF32MantissaMask = createConst(loc, i32Ty, 0x7fffff, rewriter);
+    Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
+    Value isSubnormal =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
+    Value isNegOneExp =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
+    Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
+    Value isNonZeroMan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt,
+                                                 man23Bits, zeroExpBits);
+    Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
+    Value isZeroExp =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
+    Value subnormalF4Bits = createConst(loc, i4Ty, 0xf, rewriter);
+    Value halfF4Bits = createConst(loc, i4Ty, 0x0, rewriter);
+    Value subResult =
+        b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4MagBits);
+    subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
+    f4MagBits = b.create<arith::SelectOp>(isZeroExp, f4MagBits, subResult);
+
+    // Step 5: Round to nearest, ties to even. Round up when the 22 discarded
+    // mantissa bits are above half, or exactly half and the kept mantissa bit
+    // is odd. Magnitudes below 1.0 always round up (see Step 4).
+    Value cF32Last22BitMask = createConst(loc, i32Ty, 0x3fffff, rewriter);
+    Value cRound = createConst(loc, i32Ty, 0x200000, rewriter); // 010 0000...
+    Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
+    Value isAboveHalf =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, man22Bits, cRound);
+    Value isTie =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, man22Bits, cRound);
+    Value isOddMan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ne,
+                                             man1Bit, zeroExpBits);
+    Value roundTieUp = b.create<arith::AndIOp>(isTie, isOddMan);
+    Value shouldRound = b.create<arith::OrIOp>(isAboveHalf, roundTieUp);
+    shouldRound = b.create<arith::OrIOp>(shouldRound, isSubnormal);
+    Value roundedF4Bits = b.create<arith::AddIOp>(f4MagBits, c0x1);
+    f4MagBits =
+        b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4MagBits);
+
+    // Step 6: Apply the sign. Because the input was clamped to [-6.0, 6.0]
+    // in Step 0, the magnitude is at most 0b111 here and never overlaps bit 3.
+    Value f4Bits = b.create<arith::OrIOp>(f4MagBits, f4Sign);
+
+    Value result = b.create<arith::BitcastOp>(resultTy, f4Bits);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 /*
 TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
 Since All kinds of Infs and NaNs are mapped to same exponent bits in F32 type,
@@ -498,6 +731,8 @@ struct ArithExpandOpsPass
     arith::populateArithExpandOpsPatterns(patterns);
 
     target.addLegalDialect<arith::ArithDialect>();
+    target.addLegalDialect<vector::VectorDialect>();
+
     // clang-format off
     target.addIllegalOp<
       arith::CeilDivSIOp,
@@ -515,22 +750,24 @@ struct ArithExpandOpsPass
       arith::ScalingTruncFOp
     >();
 
-    if (includeBf16) {
+    if (includeBf16)
       arith::populateExpandBFloat16Patterns(patterns);
-    }
-    if (includeF8E8M0) {
+    if (includeF8E8M0)
       arith::populateExpandF8E8M0Patterns(patterns);
-    }
+    if (includeF4E2M1)
+      arith::populateExpandF4E2M1Patterns(patterns);
 
     target.addDynamicallyLegalOp<arith::ExtFOp>(
       [=](arith::ExtFOp op) {
         Type inETy = getElementTypeOrSelf(op.getOperand().getType());
         Type outETy = getElementTypeOrSelf(op.getType());
         bool legalTypes = true;
-        if (includeBf16) 
+        if (includeBf16)
           legalTypes &= !(inETy.isBF16() && outETy.isF32());
         if (includeF8E8M0)
           legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
+        if (includeF4E2M1)
+          legalTypes &= !llvm::isa<Float4E2M1FNType>(inETy);
         return legalTypes;
       });
 
@@ -539,10 +776,12 @@ struct ArithExpandOpsPass
         Type inETy = getElementTypeOrSelf(op.getOperand().getType());
         Type outETy = getElementTypeOrSelf(op.getType());
         bool legalTypes = true;
-        if (includeBf16) 
+        if (includeBf16)
           legalTypes &= !(inETy.isF32() && outETy.isBF16());
-        if (includeF8E8M0) 
+        if (includeF8E8M0)
           legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy)); 
+        if (includeF4E2M1)
+          legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
         return legalTypes;
       });
 
@@ -567,6 +806,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
       patterns.getContext());
 }
 
+/// Registers the extf-from / truncf-to F4E2M1FN expansion patterns.
+void mlir::arith::populateExpandF4E2M1Patterns(RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(ctx);
+}
+
 void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
   patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
       patterns.getContext());

diff  --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index db1349feaff3a..8f9b0feba442a 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -verify-diagnostics -split-input-file | FileCheck %s 
+// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true include-f4e2m1=true" -verify-diagnostics -split-input-file | FileCheck %s
 // RUN: mlir-opt %s -arith-expand -split-input-file -verify-diagnostics | FileCheck %s --check-prefix=SCHECK
 
 // Test ceil divide with signed integer
@@ -593,3 +593,43 @@ func.func @minui(%a: i32, %b: i32) -> i32 {
 // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] : i32
 // CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
 // CHECK-NEXT: return %[[RESULT]] : i32
+
+// -----
+
+// Scalar truncf to f4E2M1FN must be fully expanded: no truncf remains and the
+// result is produced by a bitcast from the computed i4 bit pattern.
+func.func @truncf_f32_to_f4E2M1FN(%arg0 : f32) -> f4E2M1FN {
+    %0 = arith.truncf %arg0 : f32 to f4E2M1FN
+    return %0 : f4E2M1FN
+}
+
+// CHECK-LABEL: @truncf_f32_to_f4E2M1FN
+// CHECK-NOT: arith.truncf
+// CHECK: %[[RES:.*]] = arith.bitcast %{{.*}} : i4 to f4E2M1FN
+// CHECK-NEXT: return %[[RES]] : f4E2M1FN
+
+// -----
+
+// Vector truncf to f4E2M1FN must be fully expanded element-wise: no truncf
+// remains and the result is a bitcast from a vector of i4 bit patterns.
+func.func @truncf_vector_f32_to_f4E2M1FN(%arg0 : vector<4xf32>) -> vector<4xf4E2M1FN> {
+    %0 = arith.truncf %arg0 : vector<4xf32> to vector<4xf4E2M1FN>
+    return %0 : vector<4xf4E2M1FN>
+}
+
+// CHECK-LABEL: @truncf_vector_f32_to_f4E2M1FN
+// CHECK-NOT: arith.truncf
+// CHECK: %[[RES:.*]] = arith.bitcast %{{.*}} : vector<4xi4> to vector<4xf4E2M1FN>
+// CHECK-NEXT: return %[[RES]] : vector<4xf4E2M1FN>
+
+// -----
+
+// Scalar extf from f4E2M1FN must be fully expanded: no extf remains and the
+// result is a bitcast of the assembled i32 bit pattern.
+func.func @extf_f4E2M1FN_to_f32(%arg0 : f4E2M1FN) -> f32 {
+    %0 = arith.extf %arg0 : f4E2M1FN to f32
+    return %0 : f32
+}
+
+// CHECK-LABEL: @extf_f4E2M1FN_to_f32
+// CHECK-NOT: arith.extf
+// CHECK: %[[RES:.*]] = arith.bitcast %{{.*}} : i32 to f32
+// CHECK-NEXT: return %[[RES]] : f32
+
+// -----
+
+// Vector extf from f4E2M1FN must be fully expanded element-wise: no extf
+// remains and the result is a bitcast of the assembled vector of i32.
+func.func @extf_vector_f4E2M1FN_to_f32(%arg0 : vector<4xf4E2M1FN>) -> vector<4xf32> {
+    %0 = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
+    return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extf_vector_f4E2M1FN_to_f32
+// CHECK-NOT: arith.extf
+// CHECK: %[[RES:.*]] = arith.bitcast %{{.*}} : vector<4xi32> to vector<4xf32>
+// CHECK-NEXT: return %[[RES]] : vector<4xf32>

diff  --git a/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
new file mode 100644
index 0000000000000..9c310d80d4c2d
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
@@ -0,0 +1,73 @@
+// Check various edge cases for truncf/extf ops involving f32 and f4e2m1 types.
+
+// RUN: mlir-opt %s --convert-func-to-llvm \
+// RUN:             --arith-expand="include-f4e2m1=true" \
+// RUN:             --convert-arith-to-llvm --convert-vector-to-llvm \
+// RUN:             --reconcile-unrealized-casts | \
+// RUN:   mlir-runner -e entry --entry-point-result=void \
+// RUN:               --shared-libs=%mlir_c_runner_utils | \
+// RUN:   FileCheck %s --match-full-lines
+
+// Extends an F4E2M1FN value to F32 and prints it.
+func.func @check_extf(%x : f4E2M1FN) -> () {
+  %wide = arith.extf %x : f4E2M1FN to f32
+  vector.print %wide : f32
+  return
+}
+
+// Truncates an F32 value to F4E2M1FN and prints its raw 4-bit encoding,
+// zero-extended to i64. See
+// https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+// for details on the F4E2M1 representation.
+func.func @check_truncf(%x : f32) -> () {
+  %narrow = arith.truncf %x : f32 to f4E2M1FN
+  %raw = arith.bitcast %narrow : f4E2M1FN to i4
+  %wide = arith.extui %raw : i4 to i64
+  vector.print %wide : i64
+  return
+}
+
+func.func @entry() {
+  %zero = arith.constant 0.0 : f32
+  %half = arith.constant 0.5 : f32
+  %one = arith.constant 1.0 : f32
+  %max = arith.constant 6.0 : f32
+  %min = arith.constant -6.0 : f32
+  %lowerThanMin = arith.constant -1000000.0 : f32
+  %higherThanMax = arith.constant 1000000.0 : f32
+  %mustRound = arith.constant -3.14 : f32
+  // Quiet NaN bit pattern (the previous 0x7f80000 was missing a digit and
+  // encoded a tiny normal number, not NaN). F4E2M1FN has no NaN encoding; the
+  // expansion clamps with minnumf/maxnumf, which return the non-NaN bound, so
+  // NaN saturates to +6.0 (0b0111).
+  %nan = arith.constant 0x7fc00000 : f32
+
+  // CHECK: 0
+  func.call @check_truncf(%zero) : (f32) -> ()
+  // CHECK: 1
+  func.call @check_truncf(%half) : (f32) -> ()
+  // CHECK: 2
+  func.call @check_truncf(%one) : (f32) -> ()
+  // CHECK: 7
+  func.call @check_truncf(%max) : (f32) -> ()
+  // CHECK: 15
+  func.call @check_truncf(%min) : (f32) -> ()
+  // CHECK: 7
+  func.call @check_truncf(%higherThanMax) : (f32) -> ()
+  // CHECK: 15
+  func.call @check_truncf(%lowerThanMin) : (f32) -> ()
+  // CHECK: 13
+  func.call @check_truncf(%mustRound) : (f32) -> ()
+  // CHECK: 7
+  func.call @check_truncf(%nan) : (f32) -> ()
+
+  // CHECK: 0
+  %zeroF4 = arith.truncf %zero : f32 to f4E2M1FN
+  func.call @check_extf(%zeroF4) : (f4E2M1FN) -> ()
+  // CHECK: 0.5
+  %halfF4 = arith.truncf %half : f32 to f4E2M1FN
+  func.call @check_extf(%halfF4) : (f4E2M1FN) -> ()
+  // CHECK: 6
+  %higherThanMaxF4 = arith.truncf %higherThanMax : f32 to f4E2M1FN
+  func.call @check_extf(%higherThanMaxF4) : (f4E2M1FN) -> ()
+  // CHECK: -6
+  %lowerThanMinF4 = arith.truncf %lowerThanMin : f32 to f4E2M1FN
+  func.call @check_extf(%lowerThanMinF4) : (f4E2M1FN) -> ()
+  // CHECK: -3
+  %mustRoundF4 = arith.truncf %mustRound : f32 to f4E2M1FN
+  func.call @check_extf(%mustRoundF4) : (f4E2M1FN) -> ()
+  return
+}


        


More information about the Mlir-commits mailing list