[Mlir-commits] [mlir] fe355a4 - [MLIR][Math] Add support for f64 in the expansion of math.roundeven

Alexander Shaposhnikov llvmlistbot at llvm.org
Thu Aug 24 14:41:55 PDT 2023


Author: Alexander Shaposhnikov
Date: 2023-08-24T21:41:26Z
New Revision: fe355a44e7094a1a213a27b89d01d06243620c24

URL: https://github.com/llvm/llvm-project/commit/fe355a44e7094a1a213a27b89d01d06243620c24
DIFF: https://github.com/llvm/llvm-project/commit/fe355a44e7094a1a213a27b89d01d06243620c24.diff

LOG: [MLIR][Math] Add support for f64 in the expansion of math.roundeven

Add support for f64 in the expansion of math.roundeven.
Associated GitHub issue: https://github.com/openxla/iree/issues/13522
This is based on the offline discussion and essentially recommits
https://reviews.llvm.org/D158234.

Test plan: ninja check-mlir check-all

Added: 
    

Modified: 
    mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
    mlir/include/mlir/ExecutionEngine/Float16bits.h
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
    mlir/test/CMakeLists.txt
    mlir/test/Dialect/Math/expand-math.mlir
    mlir/test/lit.cfg.py
    mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
index 7b7e894421b407..e7798b2136af07 100644
--- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
@@ -469,8 +469,6 @@ extern "C" MLIR_CRUNNERUTILS_EXPORT void printOpen();
 extern "C" MLIR_CRUNNERUTILS_EXPORT void printClose();
 extern "C" MLIR_CRUNNERUTILS_EXPORT void printComma();
 extern "C" MLIR_CRUNNERUTILS_EXPORT void printNewline();
-extern "C" MLIR_CRUNNERUTILS_EXPORT void printF16(uint16_t bits);  // bits!
-extern "C" MLIR_CRUNNERUTILS_EXPORT void printBF16(uint16_t bits); // bits!
 
 //===----------------------------------------------------------------------===//
 // Small runtime support library for timing execution and printing GFLOPS

diff  --git a/mlir/include/mlir/ExecutionEngine/Float16bits.h b/mlir/include/mlir/ExecutionEngine/Float16bits.h
index 6bf1589aa13a2a..5eb1f2ce07639d 100644
--- a/mlir/include/mlir/ExecutionEngine/Float16bits.h
+++ b/mlir/include/mlir/ExecutionEngine/Float16bits.h
@@ -48,5 +48,8 @@ MLIR_FLOAT16_EXPORT std::ostream &operator<<(std::ostream &os, const f16 &f);
 // Outputs a bfloat value.
 MLIR_FLOAT16_EXPORT std::ostream &operator<<(std::ostream &os, const bf16 &d);
 
+extern "C" MLIR_FLOAT16_EXPORT void printF16(uint16_t bits);
+extern "C" MLIR_FLOAT16_EXPORT void printBF16(uint16_t bits);
+
 #undef MLIR_FLOAT16_EXPORT
 #endif // MLIR_EXECUTIONENGINE_FLOAT16BITS_H_

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index de363fc6370c2c..ce68fc2673dcaf 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -67,6 +67,7 @@ class FloatType : public Type {
   unsigned getWidth();
 
   /// Return the width of the mantissa of this type.
+  /// The width includes the integer bit.
   unsigned getFPMantissaWidth();
 
   /// Get or create a new FloatType with bitwidth scaled by `scale`.

diff  --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index ee8f23cf362b62..aa5fd1db528e69 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -305,31 +305,40 @@ static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
   Type operandETy = getElementTypeOrSelf(operandTy);
   Type resultETy = getElementTypeOrSelf(resultTy);
 
-  if (!operandETy.isF32() || !resultETy.isF32()) {
-    return rewriter.notifyMatchFailure(op, "not a roundeven of f32.");
+  if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
+    return rewriter.notifyMatchFailure(op, "not a roundeven of f16 or f32.");
   }
 
-  Type i32Ty = b.getI32Type();
-  Type f32Ty = b.getF32Type();
-  if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
-    i32Ty = shapedTy.clone(i32Ty);
-    f32Ty = shapedTy.clone(f32Ty);
+  Type fTy = operandTy;
+  Type iTy = rewriter.getIntegerType(operandETy.getIntOrFloatBitWidth());
+  if (auto shapedTy = dyn_cast<ShapedType>(fTy)) {
+    iTy = shapedTy.clone(iTy);
   }
 
-  Value c1Float = createFloatConst(loc, f32Ty, 1.0, b);
-  Value c0 = createIntConst(loc, i32Ty, 0, b);
-  Value c1 = createIntConst(loc, i32Ty, 1, b);
-  Value cNeg1 = createIntConst(loc, i32Ty, -1, b);
-  Value c23 = createIntConst(loc, i32Ty, 23, b);
-  Value c31 = createIntConst(loc, i32Ty, 31, b);
-  Value c127 = createIntConst(loc, i32Ty, 127, b);
-  Value c2To22 = createIntConst(loc, i32Ty, 1 << 22, b);
-  Value c23Mask = createIntConst(loc, i32Ty, (1 << 23) - 1, b);
-  Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
-
-  Value operandBitcast = b.create<arith::BitcastOp>(i32Ty, operand);
+  unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
+  // The width returned by getFPMantissaWidth includes the integer bit.
+  unsigned mantissaWidth =
+      llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
+  unsigned exponentWidth = bitWidth - mantissaWidth - 1;
+
+  // The names of the variables correspond to f32.
+  // f64: 1 bit sign | 11 bits exponent | 52 bits mantissa.
+  // f32: 1 bit sign | 8 bits exponent  | 23 bits mantissa.
+  // f16: 1 bit sign | 5 bits exponent  | 10 bits mantissa.
+  Value c1Float = createFloatConst(loc, fTy, 1.0, b);
+  Value c0 = createIntConst(loc, iTy, 0, b);
+  Value c1 = createIntConst(loc, iTy, 1, b);
+  Value cNeg1 = createIntConst(loc, iTy, -1, b);
+  Value c23 = createIntConst(loc, iTy, mantissaWidth, b);
+  Value c31 = createIntConst(loc, iTy, bitWidth - 1, b);
+  Value c127 = createIntConst(loc, iTy, (1ull << (exponentWidth - 1)) - 1, b);
+  Value c2To22 = createIntConst(loc, iTy, 1ull << (mantissaWidth - 1), b);
+  Value c23Mask = createIntConst(loc, iTy, (1ull << mantissaWidth) - 1, b);
+  Value expMask = createIntConst(loc, iTy, (1ull << exponentWidth) - 1, b);
+
+  Value operandBitcast = b.create<arith::BitcastOp>(iTy, operand);
   Value round = b.create<math::RoundOp>(operand);
-  Value roundBitcast = b.create<arith::BitcastOp>(i32Ty, round);
+  Value roundBitcast = b.create<arith::BitcastOp>(iTy, round);
 
   // Get biased exponents for operand and round(operand)
   Value operandExp = b.create<arith::AndIOp>(
@@ -340,7 +349,7 @@ static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
   Value roundBiasedExp = b.create<arith::SubIOp>(roundExp, c127);
 
   auto safeShiftRight = [&](Value x, Value shift) -> Value {
-    // Clamp shift to valid range [0, 31] to avoid undefined behavior
+    // Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior
     Value clampedShift = b.create<arith::MaxSIOp>(shift, c0);
     clampedShift = b.create<arith::MinSIOp>(clampedShift, c31);
     return b.create<arith::ShRUIOp>(x, clampedShift);

diff  --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index b28c7cf140fa3b..66a9cb01106ba5 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -119,6 +119,7 @@ if(LLVM_ENABLE_PIC AND TARGET ${LLVM_NATIVE_ARCH})
     mlir-capi-execution-engine-test
     mlir_c_runner_utils
     mlir_runner_utils
+    mlir_float16_utils
   )
 endif()
 

diff  --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index c28e2141db061a..51821e3f099a0c 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -232,8 +232,91 @@ func.func @powf_func(%a: f64, %b: f64) ->f64 {
 
 // -----
 
-// CHECK-LABEL:   func.func @roundeven
-func.func @roundeven(%arg: f32) -> f32 {
+// CHECK-LABEL:   func.func @roundeven64
+func.func @roundeven64(%arg: f64) -> f64 {
+  %res = math.roundeven %arg : f64
+  return %res : f64
+}
+
+// CHECK-SAME:                   %[[VAL_0:.*]]: f64) -> f64 {
+// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : i64
+// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : i64
+// CHECK-DAG: %[[C_NEG_1:.*]] = arith.constant -1 : i64
+// CHECK-DAG: %[[C_1_FLOAT:.*]] = arith.constant 1.000000e+00 : f64
+// CHECK-DAG: %[[C_52:.*]] = arith.constant 52 : i64
+// CHECK-DAG: %[[C_63:.*]] = arith.constant 63 : i64
+// CHECK-DAG: %[[C_1023:.*]] = arith.constant 1023 : i64
+// CHECK-DAG: %[[C_2251799813685248:.*]] = arith.constant 2251799813685248 : i64
+// CHECK-DAG: %[[C_4503599627370495:.*]] = arith.constant 4503599627370495 : i64
+// CHECK-DAG: %[[EXP_MASK:.*]] = arith.constant 2047 : i64
+// CHECK:     %[[OPERAND_BITCAST:.*]] = arith.bitcast %[[VAL_0]] : f64 to i64
+// CHECK:     %[[ROUND:.*]] = math.round %[[VAL_0]] : f64
+// CHECK:     %[[ROUND_BITCAST:.*]] = arith.bitcast %[[ROUND]] : f64 to i64
+
+// Get biased exponents of `round` and `operand`
+// CHECK:     %[[SHIFTED_OPERAND_BITCAST:.*]] = arith.shrui %[[OPERAND_BITCAST]], %[[C_52]] : i64
+// CHECK:     %[[OPERAND_EXP:.*]] = arith.andi %[[SHIFTED_OPERAND_BITCAST]], %[[EXP_MASK]] : i64
+// CHECK:     %[[OPERAND_BIASED_EXP:.*]] = arith.subi %[[OPERAND_EXP]], %[[C_1023]] : i64
+// CHECK:     %[[SHIFTED_ROUND_BITCAST:.*]] = arith.shrui %[[ROUND_BITCAST]], %[[C_52]] : i64
+// CHECK:     %[[ROUND_EXP:.*]] = arith.andi %[[SHIFTED_ROUND_BITCAST]], %[[EXP_MASK]] : i64
+// CHECK:     %[[ROUND_BIASED_EXP:.*]] = arith.subi %[[ROUND_EXP]], %[[C_1023]] : i64
+
+// Determine if `ROUND_BITCAST` is an even whole number or a special value
+// +-inf, +-nan.
+//   Mask mantissa of `ROUND_BITCAST` with a mask shifted to the right by
+//   `ROUND_BIASED_EXP - 1`
+//   CHECK-DAG: %[[ROUND_BIASED_EXP_MINUS_1:.*]] = arith.subi %[[ROUND_BIASED_EXP]], %[[C_1]] : i64
+//   CHECK-DAG: %[[CLAMPED_SHIFT_0:.*]] = arith.maxsi %[[ROUND_BIASED_EXP_MINUS_1]], %[[C_0]] : i64
+//   CHECK-DAG: %[[CLAMPED_SHIFT_1:.*]] = arith.minsi %[[CLAMPED_SHIFT_0]], %[[C_63]] : i64
+//   CHECK-DAG: %[[SHIFTED_MANTISSA_MASK_0:.*]] = arith.shrui %[[C_4503599627370495]], %[[CLAMPED_SHIFT_1]] : i64
+//   CHECK-DAG: %[[ROUND_MASKED_MANTISSA:.*]] = arith.andi %[[ROUND_BITCAST]], %[[SHIFTED_MANTISSA_MASK_0]] : i64
+
+//   `ROUND_BITCAST` is not even whole number or special value if masked
+//   mantissa is != 0 or `ROUND_BIASED_EXP == 0`
+//   CHECK-DAG: %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_0:.*]] = arith.cmpi ne, %[[ROUND_MASKED_MANTISSA]], %[[C_0]] : i64
+//   CHECK-DAG: %[[ROUND_BIASED_EXP_EQ_0:.*]] = arith.cmpi eq, %[[ROUND_BIASED_EXP]], %[[C_0]] : i64
+//   CHECK-DAG: %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_1:.*]] = arith.ori %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_0]], %[[ROUND_BIASED_EXP_EQ_0]] : i1
+
+// Determine if operand is halfway between two integer values
+// CHECK:     %[[OPERAND_BIASED_EXP_EQ_NEG_1:.*]] = arith.cmpi eq, %[[OPERAND_BIASED_EXP]], %[[C_NEG_1]] : i64
+// CHECK:     %[[CLAMPED_SHIFT_2:.*]] = arith.maxsi %[[OPERAND_BIASED_EXP]], %[[C_0]] : i64
+// CHECK:     %[[CLAMPED_SHIFT_3:.*]] = arith.minsi %[[CLAMPED_SHIFT_2]], %[[C_63]] : i64
+// CHECK:     %[[SHIFTED_2_TO_9:.*]] = arith.shrui %[[C_2251799813685248]], %[[CLAMPED_SHIFT_3]] : i64
+
+//   CHECK:     %[[EXPECTED_OPERAND_MASKED_MANTISSA:.*]] = arith.select %[[OPERAND_BIASED_EXP_EQ_NEG_1]], %[[C_0]], %[[SHIFTED_2_TO_9]] : i64
+
+//   Mask mantissa of `OPERAND_BITCAST` with a mask shifted to the right by
+//   `OPERAND_BIASED_EXP`
+//   CHECK:     %[[CLAMPED_SHIFT_4:.*]] = arith.maxsi %[[OPERAND_BIASED_EXP]], %[[C_0]] : i64
+//   CHECK:     %[[CLAMPED_SHIFT_5:.*]] = arith.minsi %[[CLAMPED_SHIFT_4]], %[[C_63]] : i64
+//   CHECK:     %[[SHIFTED_MANTISSA_MASK_1:.*]] = arith.shrui %[[C_4503599627370495]], %[[CLAMPED_SHIFT_5]] : i64
+//   CHECK:     %[[OPERAND_MASKED_MANTISSA:.*]] = arith.andi %[[OPERAND_BITCAST]], %[[SHIFTED_MANTISSA_MASK_1]] : i64
+
+//   The operand is halfway between two integers if the masked mantissa is equal
+//   to the expected mantissa and the biased exponent is in the range
+//   [-1,  52).
+//   CHECK-DAG: %[[OPERAND_BIASED_EXP_GE_NEG_1:.*]] = arith.cmpi sge, %[[OPERAND_BIASED_EXP]], %[[C_NEG_1]] : i64
+//   CHECK-DAG: %[[OPERAND_BIASED_EXP_LT_10:.*]] = arith.cmpi slt, %[[OPERAND_BIASED_EXP]], %[[C_52]] : i64
+//   CHECK-DAG: %[[OPERAND_IS_HALFWAY_0:.*]] = arith.cmpi eq, %[[OPERAND_MASKED_MANTISSA]], %[[EXPECTED_OPERAND_MASKED_MANTISSA]] : i64
+//   CHECK-DAG: %[[OPERAND_IS_HALFWAY_1:.*]] = arith.andi %[[OPERAND_IS_HALFWAY_0]], %[[OPERAND_BIASED_EXP_LT_10]] : i1
+//   CHECK-DAG: %[[OPERAND_IS_HALFWAY_2:.*]] = arith.andi %[[OPERAND_IS_HALFWAY_1]], %[[OPERAND_BIASED_EXP_GE_NEG_1]] : i1
+
+// Adjust rounded operand with `round(operand) - sign(operand)` to correct the
+// case where `round` rounded in the opposite direction of `roundeven`.
+// CHECK:     %[[SIGN:.*]] = math.copysign %[[C_1_FLOAT]], %[[VAL_0]] : f64
+// CHECK:     %[[ROUND_SHIFTED:.*]] = arith.subf %[[ROUND]], %[[SIGN]] : f64
+// CHECK:     %[[NEEDS_SHIFT:.*]] = arith.andi %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_1]], %[[OPERAND_IS_HALFWAY_2]] : i1
+// CHECK:     %[[RESULT:.*]] = arith.select %[[NEEDS_SHIFT]], %[[ROUND_SHIFTED]], %[[ROUND]] : f64
+
+// The `x - sign` adjustment does not preserve the sign when we are adjusting the value -1 to -0.
+// CHECK:     %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f64
+
+// CHECK: return %[[COPYSIGN]] : f64
+
+// -----
+
+// CHECK-LABEL:   func.func @roundeven32
+func.func @roundeven32(%arg: f32) -> f32 {
   %res = math.roundeven %arg : f32
   return %res : f32
 }
@@ -331,3 +414,90 @@ func.func @roundeven(%arg: f32) -> f32 {
 // CHECK:     %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f32
 
 // CHECK: return %[[COPYSIGN]] : f32
+
+// -----
+
+// CHECK-LABEL:   func.func @roundeven16
+func.func @roundeven16(%arg: f16) -> f16 {
+  %res = math.roundeven %arg : f16
+  return %res : f16
+}
+
+// CHECK-SAME:                   %[[VAL_0:.*]]: f16) -> f16 {
+// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : i16
+// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : i16
+// CHECK-DAG: %[[C_NEG_1:.*]] = arith.constant -1 : i16
+// CHECK-DAG: %[[C_1_FLOAT:.*]] = arith.constant 1.000000e+00 : f16
+// CHECK-DAG: %[[C_10:.*]] = arith.constant 10 : i16
+// CHECK-DAG: %[[C_15:.*]] = arith.constant 15 : i16
+// CHECK-DAG: %[[C_512:.*]] = arith.constant 512 : i16
+// CHECK-DAG: %[[C_1023:.*]] = arith.constant 1023 : i16
+// CHECK-DAG: %[[EXP_MASK:.*]] = arith.constant 31 : i16
+
+// CHECK:     %[[OPERAND_BITCAST:.*]] = arith.bitcast %[[VAL_0]] : f16 to i16
+// CHECK:     %[[ROUND:.*]] = math.round %[[VAL_0]] : f16
+// CHECK:     %[[ROUND_BITCAST:.*]] = arith.bitcast %[[ROUND]] : f16 to i16
+
+// Get biased exponents of `round` and `operand`
+// CHECK:     %[[SHIFTED_OPERAND_BITCAST:.*]] = arith.shrui %[[OPERAND_BITCAST]], %[[C_10]] : i16
+// CHECK:     %[[OPERAND_EXP:.*]] = arith.andi %[[SHIFTED_OPERAND_BITCAST]], %[[EXP_MASK]] : i16
+// CHECK:     %[[OPERAND_BIASED_EXP:.*]] = arith.subi %[[OPERAND_EXP]], %[[C_15]] : i16
+// CHECK:     %[[SHIFTED_ROUND_BITCAST:.*]] = arith.shrui %[[ROUND_BITCAST]], %[[C_10]] : i16
+// CHECK:     %[[ROUND_EXP:.*]] = arith.andi %[[SHIFTED_ROUND_BITCAST]], %[[EXP_MASK]] : i16
+// CHECK:     %[[ROUND_BIASED_EXP:.*]] = arith.subi %[[ROUND_EXP]], %[[C_15]] : i16
+
+// Determine if `ROUND_BITCAST` is an even whole number or a special value
+// +-inf, +-nan.
+//   Mask mantissa of `ROUND_BITCAST` with a mask shifted to the right by
+//   `ROUND_BIASED_EXP - 1`
+//   CHECK-DAG: %[[ROUND_BIASED_EXP_MINUS_1:.*]] = arith.subi %[[ROUND_BIASED_EXP]], %[[C_1]] : i16
+//   CHECK-DAG: %[[CLAMPED_SHIFT_0:.*]] = arith.maxsi %[[ROUND_BIASED_EXP_MINUS_1]], %[[C_0]] : i16
+//   CHECK-DAG: %[[CLAMPED_SHIFT_1:.*]] = arith.minsi %[[CLAMPED_SHIFT_0]], %[[C_15]] : i16
+//   CHECK-DAG: %[[SHIFTED_MANTISSA_MASK_0:.*]] = arith.shrui %[[C_1023]], %[[CLAMPED_SHIFT_1]] : i16
+//   CHECK-DAG: %[[ROUND_MASKED_MANTISSA:.*]] = arith.andi %[[ROUND_BITCAST]], %[[SHIFTED_MANTISSA_MASK_0]] : i16
+
+//   `ROUND_BITCAST` is not even whole number or special value if masked
+//   mantissa is != 0 or `ROUND_BIASED_EXP == 0`
+//   CHECK-DAG: %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_0:.*]] = arith.cmpi ne, %[[ROUND_MASKED_MANTISSA]], %[[C_0]] : i16
+//   CHECK-DAG: %[[ROUND_BIASED_EXP_EQ_0:.*]] = arith.cmpi eq, %[[ROUND_BIASED_EXP]], %[[C_0]] : i16
+//   CHECK-DAG: %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_1:.*]] = arith.ori %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_0]], %[[ROUND_BIASED_EXP_EQ_0]] : i1
+
+// Determine if operand is halfway between two integer values
+// CHECK:     %[[OPERAND_BIASED_EXP_EQ_NEG_1:.*]] = arith.cmpi eq, %[[OPERAND_BIASED_EXP]], %[[C_NEG_1]] : i16
+// CHECK:     %[[CLAMPED_SHIFT_2:.*]] = arith.maxsi %[[OPERAND_BIASED_EXP]], %[[C_0]] : i16
+// CHECK:     %[[CLAMPED_SHIFT_3:.*]] = arith.minsi %[[CLAMPED_SHIFT_2]], %[[C_15]] : i16
+// CHECK:     %[[SHIFTED_2_TO_9:.*]] = arith.shrui %[[C_512]], %[[CLAMPED_SHIFT_3]] : i16
+
+//   A value with `0 <= BIASED_EXP < 10` is halfway between two consecutive
+//   integers if the bit at index `BIASED_EXP` starting from the left in the
+//   mantissa is 1 and all the bits to the right are zero. For the case where
+//   `BIASED_EXP == -1`, the expected mantissa is all zeros.
+//   CHECK:     %[[EXPECTED_OPERAND_MASKED_MANTISSA:.*]] = arith.select %[[OPERAND_BIASED_EXP_EQ_NEG_1]], %[[C_0]], %[[SHIFTED_2_TO_9]] : i16
+
+//   Mask mantissa of `OPERAND_BITCAST` with a mask shifted to the right by
+//   `OPERAND_BIASED_EXP`
+//   CHECK:     %[[CLAMPED_SHIFT_4:.*]] = arith.maxsi %[[OPERAND_BIASED_EXP]], %[[C_0]] : i16
+//   CHECK:     %[[CLAMPED_SHIFT_5:.*]] = arith.minsi %[[CLAMPED_SHIFT_4]], %[[C_15]] : i16
+//   CHECK:     %[[SHIFTED_MANTISSA_MASK_1:.*]] = arith.shrui %[[C_1023]], %[[CLAMPED_SHIFT_5]] : i16
+//   CHECK:     %[[OPERAND_MASKED_MANTISSA:.*]] = arith.andi %[[OPERAND_BITCAST]], %[[SHIFTED_MANTISSA_MASK_1]] : i16
+
+//   The operand is halfway between two integers if the masked mantissa is equal
+//   to the expected mantissa and the biased exponent is in the range
+//   [-1,  10).
+//   CHECK-DAG: %[[OPERAND_BIASED_EXP_GE_NEG_1:.*]] = arith.cmpi sge, %[[OPERAND_BIASED_EXP]], %[[C_NEG_1]] : i16
+//   CHECK-DAG: %[[OPERAND_BIASED_EXP_LT_10:.*]] = arith.cmpi slt, %[[OPERAND_BIASED_EXP]], %[[C_10]] : i16
+//   CHECK-DAG: %[[OPERAND_IS_HALFWAY_0:.*]] = arith.cmpi eq, %[[OPERAND_MASKED_MANTISSA]], %[[EXPECTED_OPERAND_MASKED_MANTISSA]] : i16
+//   CHECK-DAG: %[[OPERAND_IS_HALFWAY_1:.*]] = arith.andi %[[OPERAND_IS_HALFWAY_0]], %[[OPERAND_BIASED_EXP_LT_10]] : i1
+//   CHECK-DAG: %[[OPERAND_IS_HALFWAY_2:.*]] = arith.andi %[[OPERAND_IS_HALFWAY_1]], %[[OPERAND_BIASED_EXP_GE_NEG_1]] : i1
+
+// Adjust rounded operand with `round(operand) - sign(operand)` to correct the
+// case where `round` rounded in the opposite direction of `roundeven`.
+// CHECK:     %[[SIGN:.*]] = math.copysign %[[C_1_FLOAT]], %[[VAL_0]] : f16
+// CHECK:     %[[ROUND_SHIFTED:.*]] = arith.subf %[[ROUND]], %[[SIGN]] : f16
+// CHECK:     %[[NEEDS_SHIFT:.*]] = arith.andi %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_1]], %[[OPERAND_IS_HALFWAY_2]] : i1
+// CHECK:     %[[RESULT:.*]] = arith.select %[[NEEDS_SHIFT]], %[[ROUND_SHIFTED]], %[[ROUND]] : f16
+
+// The `x - sign` adjustment does not preserve the sign when we are adjusting the value -1 to -0.
+// CHECK:     %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f16
+
+// CHECK: return %[[COPYSIGN]] : f16

diff  --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index fb99422c3ff5a3..f265ac794c6f6d 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -98,6 +98,7 @@ def add_runtime(name):
     add_runtime("mlir_runner_utils"),
     add_runtime("mlir_c_runner_utils"),
     add_runtime("mlir_async_runtime"),
+    add_runtime("mlir_float16_utils"),
     "mlir-linalg-ods-yaml-gen",
     "mlir-reduce",
     "mlir-pdll",

diff  --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index d6943e5fc2831d..6ca25edef59e79 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -3,6 +3,7 @@
 // RUN:     -e main -entry-point-result=void -O0                               \
 // RUN:     -shared-libs=%mlir_c_runner_utils  \
 // RUN:     -shared-libs=%mlir_runner_utils    \
+// RUN:     -shared-libs=%mlir_float16_utils   \
 // RUN: | FileCheck %s
 
 // -------------------------------------------------------------------------- //
@@ -243,26 +244,26 @@ func.func @powf() {
 // roundeven.
 // -------------------------------------------------------------------------- //
 
-func.func @func_roundeven(%a : f32) {
+func.func @func_roundeven32(%a : f32) {
   %b = math.roundeven %a : f32
   vector.print %b : f32
   return
 }
 
-func.func @func_roundeven$bitcast_result_to_int(%a : f32) {
+func.func @func_roundeven32$bitcast_result_to_int(%a : f32) {
   %b = math.roundeven %a : f32
   %c = arith.bitcast %b : f32 to i32
   vector.print %c : i32
   return
 }
 
-func.func @func_roundeven$vector(%a : vector<1xf32>) {
+func.func @func_roundeven32$vector(%a : vector<1xf32>) {
   %b = math.roundeven %a : vector<1xf32>
   vector.print %b : vector<1xf32>
   return
 }
 
-func.func @roundeven() {
+func.func @roundeven32() {
   %c0_25 = arith.constant 0.25 : f32
   %c0_5 = arith.constant 0.5 : f32
   %c0_75 = arith.constant 0.75 : f32
@@ -296,67 +297,66 @@ func.func @roundeven() {
   %cNeg3_75 = arith.constant -3.75 : f32
 
   // CHECK-NEXT: 0
-  call @func_roundeven(%c0_25) : (f32) -> ()
+  call @func_roundeven32(%c0_25) : (f32) -> ()
   // CHECK-NEXT: 0
-  call @func_roundeven(%c0_5) : (f32) -> ()
+  call @func_roundeven32(%c0_5) : (f32) -> ()
   // CHECK-NEXT: 1
-  call @func_roundeven(%c0_75) : (f32) -> ()
+  call @func_roundeven32(%c0_75) : (f32) -> ()
   // CHECK-NEXT: 1
-  call @func_roundeven(%c1) : (f32) -> ()
+  call @func_roundeven32(%c1) : (f32) -> ()
   // CHECK-NEXT: 1
-  call @func_roundeven(%c1_25) : (f32) -> ()
+  call @func_roundeven32(%c1_25) : (f32) -> ()
   // CHECK-NEXT: 2
-  call @func_roundeven(%c1_5) : (f32) -> ()
+  call @func_roundeven32(%c1_5) : (f32) -> ()
   // CHECK-NEXT: 2
-  call @func_roundeven(%c1_75) : (f32) -> ()
+  call @func_roundeven32(%c1_75) : (f32) -> ()
   // CHECK-NEXT: 2
-  call @func_roundeven(%c2) : (f32) -> ()
+  call @func_roundeven32(%c2) : (f32) -> ()
   // CHECK-NEXT: 2
-  call @func_roundeven(%c2_25) : (f32) -> ()
+  call @func_roundeven32(%c2_25) : (f32) -> ()
   // CHECK-NEXT: 2
-  call @func_roundeven(%c2_5) : (f32) -> ()
+  call @func_roundeven32(%c2_5) : (f32) -> ()
   // CHECK-NEXT: 3
-  call @func_roundeven(%c2_75) : (f32) -> ()
+  call @func_roundeven32(%c2_75) : (f32) -> ()
   // CHECK-NEXT: 3
-  call @func_roundeven(%c3) : (f32) -> ()
+  call @func_roundeven32(%c3) : (f32) -> ()
   // CHECK-NEXT: 3
-  call @func_roundeven(%c3_25) : (f32) -> ()
+  call @func_roundeven32(%c3_25) : (f32) -> ()
   // CHECK-NEXT: 4
-  call @func_roundeven(%c3_5) : (f32) -> ()
+  call @func_roundeven32(%c3_5) : (f32) -> ()
   // CHECK-NEXT: 4
-  call @func_roundeven(%c3_75) : (f32) -> ()
+  call @func_roundeven32(%c3_75) : (f32) -> ()
 
   // CHECK-NEXT: -0
-  call @func_roundeven(%cNeg0_25) : (f32) -> ()
+  call @func_roundeven32(%cNeg0_25) : (f32) -> ()
   // CHECK-NEXT: -0
-  call @func_roundeven(%cNeg0_5) : (f32) -> ()
+  call @func_roundeven32(%cNeg0_5) : (f32) -> ()
   // CHECK-NEXT: -1
-  call @func_roundeven(%cNeg0_75) : (f32) -> ()
+  call @func_roundeven32(%cNeg0_75) : (f32) -> ()
   // CHECK-NEXT: -1
-  call @func_roundeven(%cNeg1) : (f32) -> ()
+  call @func_roundeven32(%cNeg1) : (f32) -> ()
   // CHECK-NEXT: -1
-  call @func_roundeven(%cNeg1_25) : (f32) -> ()
+  call @func_roundeven32(%cNeg1_25) : (f32) -> ()
   // CHECK-NEXT: -2
-  call @func_roundeven(%cNeg1_5) : (f32) -> ()
+  call @func_roundeven32(%cNeg1_5) : (f32) -> ()
   // CHECK-NEXT: -2
-  call @func_roundeven(%cNeg1_75) : (f32) -> ()
+  call @func_roundeven32(%cNeg1_75) : (f32) -> ()
   // CHECK-NEXT: -2
-  call @func_roundeven(%cNeg2) : (f32) -> ()
+  call @func_roundeven32(%cNeg2) : (f32) -> ()
   // CHECK-NEXT: -2
-  call @func_roundeven(%cNeg2_25) : (f32) -> ()
+  call @func_roundeven32(%cNeg2_25) : (f32) -> ()
   // CHECK-NEXT: -2
-  call @func_roundeven(%cNeg2_5) : (f32) -> ()
+  call @func_roundeven32(%cNeg2_5) : (f32) -> ()
   // CHECK-NEXT: -3
-  call @func_roundeven(%cNeg2_75) : (f32) -> ()
+  call @func_roundeven32(%cNeg2_75) : (f32) -> ()
   // CHECK-NEXT: -3
-  call @func_roundeven(%cNeg3) : (f32) -> ()
+  call @func_roundeven32(%cNeg3) : (f32) -> ()
   // CHECK-NEXT: -3
-  call @func_roundeven(%cNeg3_25) : (f32) -> ()
+  call @func_roundeven32(%cNeg3_25) : (f32) -> ()
   // CHECK-NEXT: -4
-  call @func_roundeven(%cNeg3_5) : (f32) -> ()
+  call @func_roundeven32(%cNeg3_5) : (f32) -> ()
   // CHECK-NEXT: -4
-  call @func_roundeven(%cNeg3_75) : (f32) -> ()
-
+  call @func_roundeven32(%cNeg3_75) : (f32) -> ()
 
   // Special values: 0, -0, inf, -inf, nan, -nan
   %cNeg0 = arith.constant -0.0 : f32
@@ -371,22 +371,22 @@ func.func @roundeven() {
   %cNegNan = arith.bitcast %cNegNanInt : i32 to f32
 
   // CHECK-NEXT: -0
-  call @func_roundeven(%cNeg0) : (f32) -> ()
+  call @func_roundeven32(%cNeg0) : (f32) -> ()
   // CHECK-NEXT: 0
-  call @func_roundeven(%c0) : (f32) -> ()
+  call @func_roundeven32(%c0) : (f32) -> ()
   // CHECK-NEXT: inf
-  call @func_roundeven(%cInf) : (f32) -> ()
+  call @func_roundeven32(%cInf) : (f32) -> ()
   // CHECK-NEXT: -inf
-  call @func_roundeven(%cNegInf) : (f32) -> ()
+  call @func_roundeven32(%cNegInf) : (f32) -> ()
   // Per IEEE 754-2008, sign is not required when printing a negative NaN, so
   // print as an int to ensure input NaN is left unchanged.
   // CHECK-NEXT: 2143289344
   // CHECK-NEXT: 2143289344
-  call @func_roundeven$bitcast_result_to_int(%cNan) : (f32) -> ()
+  call @func_roundeven32$bitcast_result_to_int(%cNan) : (f32) -> ()
   vector.print %cNanInt : i32
   // CHECK-NEXT: -4194304
   // CHECK-NEXT: -4194304
-  call @func_roundeven$bitcast_result_to_int(%cNegNan) : (f32) -> ()
+  call @func_roundeven32$bitcast_result_to_int(%cNegNan) : (f32) -> ()
   vector.print %cNegNanInt : i32
 
 
@@ -402,25 +402,199 @@ func.func @roundeven() {
 
   // CHECK-NEXT: 1258291196
   // hex: 0x4AFFFFFC
-  call @func_roundeven$bitcast_result_to_int(%c8388606_5) : (f32) -> ()
+  call @func_roundeven32$bitcast_result_to_int(%c8388606_5) : (f32) -> ()
   // CHECK-NEXT: 1258291198
   // hex: 0x4AFFFFFE
-  call @func_roundeven$bitcast_result_to_int(%c8388607) : (f32) -> ()
+  call @func_roundeven32$bitcast_result_to_int(%c8388607) : (f32) -> ()
   // CHECK-NEXT: 1258291200
   // hex: 0x4B000000
-  call @func_roundeven$bitcast_result_to_int(%c8388607_5) : (f32) -> ()
+  call @func_roundeven32$bitcast_result_to_int(%c8388607_5) : (f32) -> ()
   // CHECK-NEXT: 1258291200
   // hex: 0x4B000000
-  call @func_roundeven$bitcast_result_to_int(%c8388608) : (f32) -> ()
+  call @func_roundeven32$bitcast_result_to_int(%c8388608) : (f32) -> ()
   // CHECK-NEXT: 1258291201
   // hex: 0x4B000001
-  call @func_roundeven$bitcast_result_to_int(%c8388609) : (f32) -> ()
+  call @func_roundeven32$bitcast_result_to_int(%c8388609) : (f32) -> ()
 
 
   // Check that vector type works
   %cVec = arith.constant dense<[0.5]> : vector<1xf32>
   // CHECK-NEXT: ( 0 )
-  call @func_roundeven$vector(%cVec) : (vector<1xf32>) -> ()
+  call @func_roundeven32$vector(%cVec) : (vector<1xf32>) -> ()
+  return
+}
+
+func.func @func_roundeven64(%a : f64) {
+  %b = math.roundeven %a : f64
+  vector.print %b : f64
+  return
+}
+
+func.func @func_roundeven64$bitcast_result_to_int(%a : f64) {
+  %b = math.roundeven %a : f64
+  %c = arith.bitcast %b : f64 to i64
+  vector.print %c : i64
+  return
+}
+
+func.func @func_roundeven64$vector(%a : vector<1xf64>) {
+  %b = math.roundeven %a : vector<1xf64>
+  vector.print %b : vector<1xf64>
+  return
+}
+
+func.func @roundeven64() {
+  %c0_25 = arith.constant 0.25 : f64
+  %c0_5 = arith.constant 0.5 : f64
+  %c0_75 = arith.constant 0.75 : f64
+  %c1 = arith.constant 1.0 : f64
+  %c1_25 = arith.constant 1.25 : f64
+  %c1_5 = arith.constant 1.5 : f64
+  %c1_75 = arith.constant 1.75 : f64
+  %c2 = arith.constant 2.0 : f64
+  %c2_25 = arith.constant 2.25 : f64
+  %c2_5 = arith.constant 2.5 : f64
+  %c2_75 = arith.constant 2.75 : f64
+  %c3 = arith.constant 3.0 : f64
+  %c3_25 = arith.constant 3.25 : f64
+  %c3_5 = arith.constant 3.5 : f64
+  %c3_75 = arith.constant 3.75 : f64
+
+  %cNeg0_25 = arith.constant -0.25 : f64
+  %cNeg0_5 = arith.constant -0.5 : f64
+  %cNeg0_75 = arith.constant -0.75 : f64
+  %cNeg1 = arith.constant -1.0 : f64
+  %cNeg1_25 = arith.constant -1.25 : f64
+  %cNeg1_5 = arith.constant -1.5 : f64
+  %cNeg1_75 = arith.constant -1.75 : f64
+  %cNeg2 = arith.constant -2.0 : f64
+  %cNeg2_25 = arith.constant -2.25 : f64
+  %cNeg2_5 = arith.constant -2.5 : f64
+  %cNeg2_75 = arith.constant -2.75 : f64
+  %cNeg3 = arith.constant -3.0 : f64
+  %cNeg3_25 = arith.constant -3.25 : f64
+  %cNeg3_5 = arith.constant -3.5 : f64
+  %cNeg3_75 = arith.constant -3.75 : f64
+
+  // CHECK-NEXT: 0
+  call @func_roundeven64(%c0_25) : (f64) -> ()
+  // CHECK-NEXT: 0
+  call @func_roundeven64(%c0_5) : (f64) -> ()
+  // CHECK-NEXT: 1
+  call @func_roundeven64(%c0_75) : (f64) -> ()
+  // CHECK-NEXT: 1
+  call @func_roundeven64(%c1) : (f64) -> ()
+  // CHECK-NEXT: 1
+  call @func_roundeven64(%c1_25) : (f64) -> ()
+  // CHECK-NEXT: 2
+  call @func_roundeven64(%c1_5) : (f64) -> ()
+  // CHECK-NEXT: 2
+  call @func_roundeven64(%c1_75) : (f64) -> ()
+  // CHECK-NEXT: 2
+  call @func_roundeven64(%c2) : (f64) -> ()
+  // CHECK-NEXT: 2
+  call @func_roundeven64(%c2_25) : (f64) -> ()
+  // CHECK-NEXT: 2
+  call @func_roundeven64(%c2_5) : (f64) -> ()
+  // CHECK-NEXT: 3
+  call @func_roundeven64(%c2_75) : (f64) -> ()
+  // CHECK-NEXT: 3
+  call @func_roundeven64(%c3) : (f64) -> ()
+  // CHECK-NEXT: 3
+  call @func_roundeven64(%c3_25) : (f64) -> ()
+  // CHECK-NEXT: 4
+  call @func_roundeven64(%c3_5) : (f64) -> ()
+  // CHECK-NEXT: 4
+  call @func_roundeven64(%c3_75) : (f64) -> ()
+
+  // CHECK-NEXT: -0
+  call @func_roundeven64(%cNeg0_25) : (f64) -> ()
+  // CHECK-NEXT: -0
+  call @func_roundeven64(%cNeg0_5) : (f64) -> ()
+  // CHECK-NEXT: -1
+  call @func_roundeven64(%cNeg0_75) : (f64) -> ()
+  // CHECK-NEXT: -1
+  call @func_roundeven64(%cNeg1) : (f64) -> ()
+  // CHECK-NEXT: -1
+  call @func_roundeven64(%cNeg1_25) : (f64) -> ()
+  // CHECK-NEXT: -2
+  call @func_roundeven64(%cNeg1_5) : (f64) -> ()
+  // CHECK-NEXT: -2
+  call @func_roundeven64(%cNeg1_75) : (f64) -> ()
+  // CHECK-NEXT: -2
+  call @func_roundeven64(%cNeg2) : (f64) -> ()
+  // CHECK-NEXT: -2
+  call @func_roundeven64(%cNeg2_25) : (f64) -> ()
+  // CHECK-NEXT: -2
+  call @func_roundeven64(%cNeg2_5) : (f64) -> ()
+  // CHECK-NEXT: -3
+  call @func_roundeven64(%cNeg2_75) : (f64) -> ()
+  // CHECK-NEXT: -3
+  call @func_roundeven64(%cNeg3) : (f64) -> ()
+  // CHECK-NEXT: -3
+  call @func_roundeven64(%cNeg3_25) : (f64) -> ()
+  // CHECK-NEXT: -4
+  call @func_roundeven64(%cNeg3_5) : (f64) -> ()
+  // CHECK-NEXT: -4
+  call @func_roundeven64(%cNeg3_75) : (f64) -> ()
+
+  // Special values: 0, -0, inf, -inf, nan, -nan
+  %cNeg0 = arith.constant -0.0 : f64
+  %c0 = arith.constant 0.0 : f64
+  %cInfInt = arith.constant 0x7FF0000000000000 : i64
+  %cInf = arith.bitcast %cInfInt : i64 to f64
+  %cNegInfInt = arith.constant 0xFFF0000000000000 : i64
+  %cNegInf = arith.bitcast %cNegInfInt : i64 to f64
+  %cNanInt = arith.constant 0x7FF0000000000001 : i64
+  %cNan = arith.bitcast %cNanInt : i64 to f64
+  %cNegNanInt = arith.constant 0xFFF0000000000001 : i64
+  %cNegNan = arith.bitcast %cNegNanInt : i64 to f64
+
+  // CHECK-NEXT: -0
+  call @func_roundeven64(%cNeg0) : (f64) -> ()
+  // CHECK-NEXT: 0
+  call @func_roundeven64(%c0) : (f64) -> ()
+  // CHECK-NEXT: inf
+  call @func_roundeven64(%cInf) : (f64) -> ()
+  // CHECK-NEXT: -inf
+  call @func_roundeven64(%cNegInf) : (f64) -> ()
+
+  // Values above and below 2^52 = 4503599627370496
+  %c4503599627370494_5 = arith.constant 4503599627370494.5 : f64
+  %c4503599627370495 = arith.constant 4503599627370495.0 : f64
+  %c4503599627370495_5 = arith.constant 4503599627370495.5 : f64
+  %c4503599627370496 = arith.constant 4503599627370496.0 : f64
+  %c4503599627370497 = arith.constant 4503599627370497.0 : f64
+
+  // Bitcast result to int to avoid printing in scientific notation,
+  // which does not display all significant digits.
+
+  // CHECK-NEXT: 4841369599423283196
+  // hex: 0x432ffffffffffffc
+  call @func_roundeven64$bitcast_result_to_int(%c4503599627370494_5) : (f64) -> ()
+  // CHECK-NEXT: 4841369599423283198
+  // hex: 0x432ffffffffffffe
+  call @func_roundeven64$bitcast_result_to_int(%c4503599627370495) : (f64) -> ()
+  // CHECK-NEXT: 4841369599423283200
+  // hex: 0x4330000000000000
+  call @func_roundeven64$bitcast_result_to_int(%c4503599627370495_5) : (f64) -> ()
+  // CHECK-NEXT: 4841369599423283200
+  // hex: 0x4330000000000000
+  call @func_roundeven64$bitcast_result_to_int(%c4503599627370496) : (f64) -> ()
+  // CHECK-NEXT: 4841369599423283201
+  // hex: 0x4330000000000001
+  call @func_roundeven64$bitcast_result_to_int(%c4503599627370497) : (f64) -> ()
+
+  // Check that vector type works
+  %cVec = arith.constant dense<[0.5]> : vector<1xf64>
+  // CHECK-NEXT: ( 0 )
+  call @func_roundeven64$vector(%cVec) : (vector<1xf64>) -> ()
+  return
+}
+
+func.func @roundeven() {
+  call @roundeven32() : () -> ()
+  call @roundeven64() : () -> ()
   return
 }
 


        


More information about the Mlir-commits mailing list