[Mlir-commits] [mlir] 8d2bae9 - Add pattern that expands `math.roundeven` into `math.round` + arith
Jacques Pienaar
llvmlistbot at llvm.org
Thu Apr 20 12:48:27 PDT 2023
Author: Ramiro Leal-Cavazos
Date: 2023-04-20T12:48:12-07:00
New Revision: 8d2bae9abdc30e104bab00a4dd0f9d39f5bdda6e
URL: https://github.com/llvm/llvm-project/commit/8d2bae9abdc30e104bab00a4dd0f9d39f5bdda6e
DIFF: https://github.com/llvm/llvm-project/commit/8d2bae9abdc30e104bab00a4dd0f9d39f5bdda6e.diff
LOG: Add pattern that expands `math.roundeven` into `math.round` + arith
This commit adds a pattern that expands `math.roundeven` into
`math.round` + some ops from `arith`. This is needed to be able to run
`math.roundeven` in a vectorized manner.
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D148285
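For reference, here is a scalar C++ model (not part of this patch; the helper
name and the use of <cmath> calls in place of the generated `math`/`arith` ops
are assumptions) of the round-half-to-even behavior the expansion implements:

    #include <cmath>

    // Hypothetical reference model: round half away from zero (what math.round
    // does), then step back toward the even neighbor when the input was exactly
    // halfway and round() landed on an odd integer.
    static float roundeven_ref(float x) {
      float r = std::round(x);                   // halfway cases go away from zero
      bool halfway = std::fabs(x - std::trunc(x)) == 0.5f;
      bool odd = std::fmod(r, 2.0f) != 0.0f;
      if (halfway && odd)
        r -= std::copysign(1.0f, x);             // round(operand) - sign(operand)
      return std::copysign(r, x);                // keep the sign so -0.5 gives -0.0
    }

The actual pattern makes the same decision by inspecting the IEEE-754 exponent
and mantissa bits directly, which also works elementwise on vector operands.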
Added:
Modified:
mlir/include/mlir/Dialect/Math/Transforms/Passes.h
mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
mlir/test/Dialect/Math/expand-math.mlir
mlir/test/lib/Dialect/Math/TestExpandMath.cpp
mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 245a11747d5c8..576ace34eac1c 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -22,6 +22,7 @@ void populateExpandCeilFPattern(RewritePatternSet &patterns);
void populateExpandExp2FPattern(RewritePatternSet &patterns);
void populateExpandPowFPattern(RewritePatternSet &patterns);
void populateExpandRoundFPattern(RewritePatternSet &patterns);
+void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
struct MathPolynomialApproximationOptions {
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 62337284e1afb..ee8f23cf362b6 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -294,6 +294,129 @@ static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
return success();
}
+// Convert `math.roundeven` into `math.round` + arith ops
+static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
+ PatternRewriter &rewriter) {
+ Location loc = op.getLoc();
+ ImplicitLocOpBuilder b(loc, rewriter);
+ auto operand = op.getOperand();
+ Type operandTy = operand.getType();
+ Type resultTy = op.getType();
+ Type operandETy = getElementTypeOrSelf(operandTy);
+ Type resultETy = getElementTypeOrSelf(resultTy);
+
+ if (!operandETy.isF32() || !resultETy.isF32()) {
+ return rewriter.notifyMatchFailure(op, "not a roundeven of f32.");
+ }
+
+ Type i32Ty = b.getI32Type();
+ Type f32Ty = b.getF32Type();
+ if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
+ i32Ty = shapedTy.clone(i32Ty);
+ f32Ty = shapedTy.clone(f32Ty);
+ }
+
+ Value c1Float = createFloatConst(loc, f32Ty, 1.0, b);
+ Value c0 = createIntConst(loc, i32Ty, 0, b);
+ Value c1 = createIntConst(loc, i32Ty, 1, b);
+ Value cNeg1 = createIntConst(loc, i32Ty, -1, b);
+ Value c23 = createIntConst(loc, i32Ty, 23, b);
+ Value c31 = createIntConst(loc, i32Ty, 31, b);
+ Value c127 = createIntConst(loc, i32Ty, 127, b);
+ Value c2To22 = createIntConst(loc, i32Ty, 1 << 22, b);
+ Value c23Mask = createIntConst(loc, i32Ty, (1 << 23) - 1, b);
+ Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
+
+ Value operandBitcast = b.create<arith::BitcastOp>(i32Ty, operand);
+ Value round = b.create<math::RoundOp>(operand);
+ Value roundBitcast = b.create<arith::BitcastOp>(i32Ty, round);
+
+ // Get biased exponents for operand and round(operand)
+ Value operandExp = b.create<arith::AndIOp>(
+ b.create<arith::ShRUIOp>(operandBitcast, c23), expMask);
+ Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127);
+ Value roundExp = b.create<arith::AndIOp>(
+ b.create<arith::ShRUIOp>(roundBitcast, c23), expMask);
+ Value roundBiasedExp = b.create<arith::SubIOp>(roundExp, c127);
+
+ auto safeShiftRight = [&](Value x, Value shift) -> Value {
+ // Clamp shift to valid range [0, 31] to avoid undefined behavior
+ Value clampedShift = b.create<arith::MaxSIOp>(shift, c0);
+ clampedShift = b.create<arith::MinSIOp>(clampedShift, c31);
+ return b.create<arith::ShRUIOp>(x, clampedShift);
+ };
+
+ auto maskMantissa = [&](Value mantissa,
+ Value mantissaMaskRightShift) -> Value {
+ Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
+ return b.create<arith::AndIOp>(mantissa, shiftedMantissaMask);
+ };
+
+ // A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring
+ // the leftmost `clamp(biasedExp - 1, 0, 23)` bits, is zero. Large numbers
+ // with `biasedExp > 23` (numbers where there is not enough precision to store
+ // decimals) are always even, and they satisfy the even condition trivially
+ // since the mantissa without all its bits is zero. The even condition
+ // is also true for +-0, since they have `biasedExp = -127` and the entire
+ // mantissa is zero. The case of +-1 has to be handled separately. Here
+ // we identify these values by noting that +-1 are the only whole numbers with
+ // `biasedExp == 0`.
+ //
+ // The special values +-inf and +-nan also satisfy the same property that
+ // whole non-unit even numbers satisfy. In particular, the special values have
+ // `biasedExp > 23`, so they get treated as large numbers with no room for
+ // decimals, which are always even.
+ Value roundBiasedExpEq0 =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, roundBiasedExp, c0);
+ Value roundBiasedExpMinus1 = b.create<arith::SubIOp>(roundBiasedExp, c1);
+ Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
+ Value roundIsNotEvenOrSpecialVal = b.create<arith::CmpIOp>(
+ arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
+ roundIsNotEvenOrSpecialVal =
+ b.create<arith::OrIOp>(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
+
+  // A value `x` with `0 <= biasedExp < 23` is halfway between two consecutive
+ // integers if the bit at index `biasedExp` starting from the left in the
+ // mantissa is 1 and all the bits to the right are zero. Values with
+ // `biasedExp >= 23` don't have decimals, so they are never halfway. The
+ // values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`,
+ // so these are handled separately. In particular, if `biasedExp == -1`, the
+ // value is halfway if the entire mantissa is zero.
+ Value operandBiasedExpEqNeg1 = b.create<arith::CmpIOp>(
+ arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
+ Value expectedOperandMaskedMantissa = b.create<arith::SelectOp>(
+ operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
+ Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
+ Value operandIsHalfway =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, operandMaskedMantissa,
+ expectedOperandMaskedMantissa);
+  // Ensure `biasedExp` is in the valid range for halfway values.
+ Value operandBiasedExpGeNeg1 = b.create<arith::CmpIOp>(
+ arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
+ Value operandBiasedExpLt23 =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, operandBiasedExp, c23);
+ operandIsHalfway =
+ b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpLt23);
+ operandIsHalfway =
+ b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpGeNeg1);
+
+ // Adjust rounded operand with `round(operand) - sign(operand)` to correct the
+ // case where `round` rounded in the opposite direction of `roundeven`.
+ Value sign = b.create<math::CopySignOp>(c1Float, operand);
+ Value roundShifted = b.create<arith::SubFOp>(round, sign);
+ // If the rounded value is even or a special value, we default to the behavior
+ // of `math.round`.
+ Value needsShift =
+ b.create<arith::AndIOp>(roundIsNotEvenOrSpecialVal, operandIsHalfway);
+ Value result = b.create<arith::SelectOp>(needsShift, roundShifted, round);
+ // The `x - sign` adjustment does not preserve the sign when we are adjusting
+ // the value -1 to -0. So here the sign is copied again to ensure that -0.5 is
+ // rounded to -0.0.
+ result = b.create<math::CopySignOp>(result, operand);
+ rewriter.replaceOp(op, result);
+ return success();
+}
+
void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
patterns.add(convertCtlzOp);
}
@@ -329,3 +452,7 @@ void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
patterns.add(convertFloorOp);
}
+
+void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) {
+ patterns.add(convertRoundEvenOp);
+}
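Downstream code can opt into this expansion through the new populate entry
point; a minimal sketch (the wrapper function is an assumption, mirroring the
TestExpandMath.cpp change below):

    #include "mlir/Dialect/Math/Transforms/Passes.h"
    #include "mlir/IR/PatternMatch.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    // Hypothetical helper: collect the roundeven expansion and apply it greedily.
    void expandRoundEvenOps(mlir::Operation *op) {
      mlir::RewritePatternSet patterns(op->getContext());
      mlir::populateExpandRoundEvenPattern(patterns);
      (void)mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
    }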
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 1e9a3da2ac9c9..c28e2141db061 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -229,3 +229,105 @@ func.func @powf_func(%a: f64, %b: f64) ->f64 {
%ret = math.powf %a, %b : f64
return %ret : f64
}
+
+// -----
+
+// CHECK-LABEL: func.func @roundeven
+func.func @roundeven(%arg: f32) -> f32 {
+ %res = math.roundeven %arg : f32
+ return %res : f32
+}
+
+// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[C_NEG_1:.*]] = arith.constant -1 : i32
+// CHECK-DAG: %[[C_1_FLOAT:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG: %[[C_23:.*]] = arith.constant 23 : i32
+// CHECK-DAG: %[[C_31:.*]] = arith.constant 31 : i32
+// CHECK-DAG: %[[C_127:.*]] = arith.constant 127 : i32
+// CHECK-DAG: %[[C_4194304:.*]] = arith.constant 4194304 : i32
+// CHECK-DAG: %[[C_8388607:.*]] = arith.constant 8388607 : i32
+// CHECK-DAG: %[[EXP_MASK:.*]] = arith.constant 255 : i32
+// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01
+
+// CHECK: %[[OPERAND_BITCAST:.*]] = arith.bitcast %[[VAL_0]] : f32 to i32
+
+// Calculate `math.round(operand)` using expansion pattern for `round` and
+// bitcast result to i32
+// CHECK: %[[SHIFT:.*]] = math.copysign %[[HALF]], %[[VAL_0]]
+// CHECK: %[[ARG_SHIFTED:.*]] = arith.addf %[[VAL_0]], %[[SHIFT]]
+// CHECK: %[[FIXED_CONVERT:.*]] = arith.fptosi %[[ARG_SHIFTED]]
+// CHECK: %[[FP_FIXED_CONVERT_0:.*]] = arith.sitofp %[[FIXED_CONVERT]]
+// CHECK: %[[FP_FIXED_CONVERT_1:.*]] = math.copysign %[[FP_FIXED_CONVERT_0]], %[[ARG_SHIFTED]]
+// CHECK: %[[ARG_BITCAST:.*]] = arith.bitcast %[[VAL_0]] : f32 to i32
+// CHECK: %[[ARG_BITCAST_SHIFTED:.*]] = arith.shrui %[[ARG_BITCAST]], %[[C_23]]
+// CHECK: %[[ARG_EXP:.*]] = arith.andi %[[ARG_BITCAST_SHIFTED]], %[[EXP_MASK]]
+// CHECK: %[[ARG_BIASED_EXP:.*]] = arith.subi %[[ARG_EXP]], %[[C_127]]
+// CHECK: %[[IS_SPECIAL_VAL:.*]] = arith.cmpi sge, %[[ARG_BIASED_EXP]], %[[C_23]]
+// CHECK: %[[ROUND:.*]] = arith.select %[[IS_SPECIAL_VAL]], %[[VAL_0]], %[[FP_FIXED_CONVERT_1]]
+// CHECK: %[[ROUND_BITCAST:.*]] = arith.bitcast %[[ROUND]] : f32 to i32
+
+// Get biased exponents of `round` and `operand`
+// CHECK: %[[SHIFTED_OPERAND_BITCAST:.*]] = arith.shrui %[[OPERAND_BITCAST]], %[[C_23]] : i32
+// CHECK: %[[OPERAND_EXP:.*]] = arith.andi %[[SHIFTED_OPERAND_BITCAST]], %[[EXP_MASK]] : i32
+// CHECK: %[[OPERAND_BIASED_EXP:.*]] = arith.subi %[[OPERAND_EXP]], %[[C_127]] : i32
+// CHECK: %[[SHIFTED_ROUND_BITCAST:.*]] = arith.shrui %[[ROUND_BITCAST]], %[[C_23]] : i32
+// CHECK: %[[ROUND_EXP:.*]] = arith.andi %[[SHIFTED_ROUND_BITCAST]], %[[EXP_MASK]] : i32
+// CHECK: %[[ROUND_BIASED_EXP:.*]] = arith.subi %[[ROUND_EXP]], %[[C_127]] : i32
+
+// Determine if `ROUND_BITCAST` is an even whole number or a special value
+// +-inf, +-nan.
+// Mask mantissa of `ROUND_BITCAST` with a mask shifted to the right by
+// `ROUND_BIASED_EXP - 1`
+// CHECK-DAG: %[[ROUND_BIASED_EXP_MINUS_1:.*]] = arith.subi %[[ROUND_BIASED_EXP]], %[[C_1]] : i32
+// CHECK-DAG: %[[CLAMPED_SHIFT_0:.*]] = arith.maxsi %[[ROUND_BIASED_EXP_MINUS_1]], %[[C_0]] : i32
+// CHECK-DAG: %[[CLAMPED_SHIFT_1:.*]] = arith.minsi %[[CLAMPED_SHIFT_0]], %[[C_31]] : i32
+// CHECK-DAG: %[[SHIFTED_MANTISSA_MASK_0:.*]] = arith.shrui %[[C_8388607]], %[[CLAMPED_SHIFT_1]] : i32
+// CHECK-DAG: %[[ROUND_MASKED_MANTISSA:.*]] = arith.andi %[[ROUND_BITCAST]], %[[SHIFTED_MANTISSA_MASK_0]] : i32
+
+// `ROUND_BITCAST` is not an even whole number or a special value if the
+// masked mantissa is != 0 or `ROUND_BIASED_EXP == 0`
+// CHECK-DAG: %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_0:.*]] = arith.cmpi ne, %[[ROUND_MASKED_MANTISSA]], %[[C_0]] : i32
+// CHECK-DAG: %[[ROUND_BIASED_EXP_EQ_0:.*]] = arith.cmpi eq, %[[ROUND_BIASED_EXP]], %[[C_0]] : i32
+// CHECK-DAG: %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_1:.*]] = arith.ori %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_0]], %[[ROUND_BIASED_EXP_EQ_0]] : i1
+
+// Determine if operand is halfway between two integer values
+// CHECK: %[[OPERAND_BIASED_EXP_EQ_NEG_1:.*]] = arith.cmpi eq, %[[OPERAND_BIASED_EXP]], %[[C_NEG_1]] : i32
+// CHECK: %[[CLAMPED_SHIFT_2:.*]] = arith.maxsi %[[OPERAND_BIASED_EXP]], %[[C_0]] : i32
+// CHECK: %[[CLAMPED_SHIFT_3:.*]] = arith.minsi %[[CLAMPED_SHIFT_2]], %[[C_31]] : i32
+// CHECK: %[[SHIFTED_2_TO_22:.*]] = arith.shrui %[[C_4194304]], %[[CLAMPED_SHIFT_3]] : i32
+
+// A value with `0 <= BIASED_EXP < 23` is halfway between two consecutive
+// integers if the bit at index `BIASED_EXP` starting from the left in the
+// mantissa is 1 and all the bits to the right are zero. For the case where
+// `BIASED_EXP == -1`, the expected mantissa is all zeros.
+// CHECK: %[[EXPECTED_OPERAND_MASKED_MANTISSA:.*]] = arith.select %[[OPERAND_BIASED_EXP_EQ_NEG_1]], %[[C_0]], %[[SHIFTED_2_TO_22]] : i32
+
+// Mask mantissa of `OPERAND_BITCAST` with a mask shifted to the right by
+// `OPERAND_BIASED_EXP`
+// CHECK: %[[CLAMPED_SHIFT_4:.*]] = arith.maxsi %[[OPERAND_BIASED_EXP]], %[[C_0]] : i32
+// CHECK: %[[CLAMPED_SHIFT_5:.*]] = arith.minsi %[[CLAMPED_SHIFT_4]], %[[C_31]] : i32
+// CHECK: %[[SHIFTED_MANTISSA_MASK_1:.*]] = arith.shrui %[[C_8388607]], %[[CLAMPED_SHIFT_5]] : i32
+// CHECK: %[[OPERAND_MASKED_MANTISSA:.*]] = arith.andi %[[OPERAND_BITCAST]], %[[SHIFTED_MANTISSA_MASK_1]] : i32
+
+// The operand is halfway between two integers if the masked mantissa is equal
+// to the expected mantissa and the biased exponent is in the range
+// [-1, 23).
+// CHECK-DAG: %[[OPERAND_BIASED_EXP_GE_NEG_1:.*]] = arith.cmpi sge, %[[OPERAND_BIASED_EXP]], %[[C_NEG_1]] : i32
+// CHECK-DAG: %[[OPERAND_BIASED_EXP_LT_23:.*]] = arith.cmpi slt, %[[OPERAND_BIASED_EXP]], %[[C_23]] : i32
+// CHECK-DAG: %[[OPERAND_IS_HALFWAY_0:.*]] = arith.cmpi eq, %[[OPERAND_MASKED_MANTISSA]], %[[EXPECTED_OPERAND_MASKED_MANTISSA]] : i32
+// CHECK-DAG: %[[OPERAND_IS_HALFWAY_1:.*]] = arith.andi %[[OPERAND_IS_HALFWAY_0]], %[[OPERAND_BIASED_EXP_LT_23]] : i1
+// CHECK-DAG: %[[OPERAND_IS_HALFWAY_2:.*]] = arith.andi %[[OPERAND_IS_HALFWAY_1]], %[[OPERAND_BIASED_EXP_GE_NEG_1]] : i1
+
+// Adjust rounded operand with `round(operand) - sign(operand)` to correct the
+// case where `round` rounded in the opposite direction of `roundeven`.
+// CHECK: %[[SIGN:.*]] = math.copysign %[[C_1_FLOAT]], %[[VAL_0]] : f32
+// CHECK: %[[ROUND_SHIFTED:.*]] = arith.subf %[[ROUND]], %[[SIGN]] : f32
+// CHECK: %[[NEEDS_SHIFT:.*]] = arith.andi %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_1]], %[[OPERAND_IS_HALFWAY_2]] : i1
+// CHECK: %[[RESULT:.*]] = arith.select %[[NEEDS_SHIFT]], %[[ROUND_SHIFTED]], %[[ROUND]] : f32
+
+// The `x - sign` adjustment does not preserve the sign when we are adjusting the value -1 to -0.
+// CHECK: %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f32
+
+// CHECK: return %[[COPYSIGN]] : f32
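As a concrete sanity check (not part of the patch), the halfway test above can
be traced for x = 2.5f, whose bit pattern is 0x40200000:

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    int main() {
      float x = 2.5f;                      // exactly halfway between 2.0 and 3.0
      uint32_t bits;
      std::memcpy(&bits, &x, sizeof bits); // same as arith.bitcast: 0x40200000
      int32_t biasedExp = int32_t((bits >> 23) & 0xFF) - 127;    // 128 - 127 = 1
      // biasedExp is 1 for this input, so no clamping to [0, 31] is needed here.
      uint32_t maskedMantissa = bits & (0x7FFFFFu >> biasedExp); // 0x200000
      uint32_t expected = 0x400000u >> biasedExp;                // 0x200000
      assert(biasedExp >= -1 && biasedExp < 23);
      assert(maskedMantissa == expected);  // 2.5 is detected as halfway
      return 0;
    }

Since round(2.5) = 3.0 has a nonzero masked mantissa (it is odd), the final
select picks round - sign = 2.0, which matches the runtime test below where
2.5 prints as 2.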
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index c9b3357c9b508..6dae8213dd41e 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -45,6 +45,7 @@ void TestExpandMathPass::runOnOperation() {
populateExpandCeilFPattern(patterns);
populateExpandPowFPattern(patterns);
populateExpandRoundFPattern(patterns);
+ populateExpandRoundEvenPattern(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index 6d655fe4bb737..2a27d0f6d37c3 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -233,9 +233,190 @@ func.func @powf() {
return
}
+// -------------------------------------------------------------------------- //
+// roundeven.
+// -------------------------------------------------------------------------- //
+
+func.func @func_roundeven(%a : f32) {
+ %b = math.roundeven %a : f32
+ vector.print %b : f32
+ return
+}
+
+func.func @func_roundeven$bitcast_result_to_int(%a : f32) {
+ %b = math.roundeven %a : f32
+ %c = arith.bitcast %b : f32 to i32
+ vector.print %c : i32
+ return
+}
+
+func.func @func_roundeven$vector(%a : vector<1xf32>) {
+ %b = math.roundeven %a : vector<1xf32>
+ vector.print %b : vector<1xf32>
+ return
+}
+
+func.func @roundeven() {
+ %c0_25 = arith.constant 0.25 : f32
+ %c0_5 = arith.constant 0.5 : f32
+ %c0_75 = arith.constant 0.75 : f32
+ %c1 = arith.constant 1.0 : f32
+ %c1_25 = arith.constant 1.25 : f32
+ %c1_5 = arith.constant 1.5 : f32
+ %c1_75 = arith.constant 1.75 : f32
+ %c2 = arith.constant 2.0 : f32
+ %c2_25 = arith.constant 2.25 : f32
+ %c2_5 = arith.constant 2.5 : f32
+ %c2_75 = arith.constant 2.75 : f32
+ %c3 = arith.constant 3.0 : f32
+ %c3_25 = arith.constant 3.25 : f32
+ %c3_5 = arith.constant 3.5 : f32
+ %c3_75 = arith.constant 3.75 : f32
+
+ %cNeg0_25 = arith.constant -0.25 : f32
+ %cNeg0_5 = arith.constant -0.5 : f32
+ %cNeg0_75 = arith.constant -0.75 : f32
+ %cNeg1 = arith.constant -1.0 : f32
+ %cNeg1_25 = arith.constant -1.25 : f32
+ %cNeg1_5 = arith.constant -1.5 : f32
+ %cNeg1_75 = arith.constant -1.75 : f32
+ %cNeg2 = arith.constant -2.0 : f32
+ %cNeg2_25 = arith.constant -2.25 : f32
+ %cNeg2_5 = arith.constant -2.5 : f32
+ %cNeg2_75 = arith.constant -2.75 : f32
+ %cNeg3 = arith.constant -3.0 : f32
+ %cNeg3_25 = arith.constant -3.25 : f32
+ %cNeg3_5 = arith.constant -3.5 : f32
+ %cNeg3_75 = arith.constant -3.75 : f32
+
+ // CHECK-NEXT: 0
+ call @func_roundeven(%c0_25) : (f32) -> ()
+ // CHECK-NEXT: 0
+ call @func_roundeven(%c0_5) : (f32) -> ()
+ // CHECK-NEXT: 1
+ call @func_roundeven(%c0_75) : (f32) -> ()
+ // CHECK-NEXT: 1
+ call @func_roundeven(%c1) : (f32) -> ()
+ // CHECK-NEXT: 1
+ call @func_roundeven(%c1_25) : (f32) -> ()
+ // CHECK-NEXT: 2
+ call @func_roundeven(%c1_5) : (f32) -> ()
+ // CHECK-NEXT: 2
+ call @func_roundeven(%c1_75) : (f32) -> ()
+ // CHECK-NEXT: 2
+ call @func_roundeven(%c2) : (f32) -> ()
+ // CHECK-NEXT: 2
+ call @func_roundeven(%c2_25) : (f32) -> ()
+ // CHECK-NEXT: 2
+ call @func_roundeven(%c2_5) : (f32) -> ()
+ // CHECK-NEXT: 3
+ call @func_roundeven(%c2_75) : (f32) -> ()
+ // CHECK-NEXT: 3
+ call @func_roundeven(%c3) : (f32) -> ()
+ // CHECK-NEXT: 3
+ call @func_roundeven(%c3_25) : (f32) -> ()
+ // CHECK-NEXT: 4
+ call @func_roundeven(%c3_5) : (f32) -> ()
+ // CHECK-NEXT: 4
+ call @func_roundeven(%c3_75) : (f32) -> ()
+
+ // CHECK-NEXT: -0
+ call @func_roundeven(%cNeg0_25) : (f32) -> ()
+ // CHECK-NEXT: -0
+ call @func_roundeven(%cNeg0_5) : (f32) -> ()
+ // CHECK-NEXT: -1
+ call @func_roundeven(%cNeg0_75) : (f32) -> ()
+ // CHECK-NEXT: -1
+ call @func_roundeven(%cNeg1) : (f32) -> ()
+ // CHECK-NEXT: -1
+ call @func_roundeven(%cNeg1_25) : (f32) -> ()
+ // CHECK-NEXT: -2
+ call @func_roundeven(%cNeg1_5) : (f32) -> ()
+ // CHECK-NEXT: -2
+ call @func_roundeven(%cNeg1_75) : (f32) -> ()
+ // CHECK-NEXT: -2
+ call @func_roundeven(%cNeg2) : (f32) -> ()
+ // CHECK-NEXT: -2
+ call @func_roundeven(%cNeg2_25) : (f32) -> ()
+ // CHECK-NEXT: -2
+ call @func_roundeven(%cNeg2_5) : (f32) -> ()
+ // CHECK-NEXT: -3
+ call @func_roundeven(%cNeg2_75) : (f32) -> ()
+ // CHECK-NEXT: -3
+ call @func_roundeven(%cNeg3) : (f32) -> ()
+ // CHECK-NEXT: -3
+ call @func_roundeven(%cNeg3_25) : (f32) -> ()
+ // CHECK-NEXT: -4
+ call @func_roundeven(%cNeg3_5) : (f32) -> ()
+ // CHECK-NEXT: -4
+ call @func_roundeven(%cNeg3_75) : (f32) -> ()
+
+
+ // Special values: 0, -0, inf, -inf, nan, -nan
+ %cNeg0 = arith.constant -0.0 : f32
+ %c0 = arith.constant 0.0 : f32
+ %cInfInt = arith.constant 0x7f800000 : i32
+ %cInf = arith.bitcast %cInfInt : i32 to f32
+ %cNegInfInt = arith.constant 0xff800000 : i32
+ %cNegInf = arith.bitcast %cNegInfInt : i32 to f32
+ %cNanInt = arith.constant 0x7fc00000 : i32
+ %cNan = arith.bitcast %cNanInt : i32 to f32
+ %cNegNanInt = arith.constant 0xffc00000 : i32
+ %cNegNan = arith.bitcast %cNegNanInt : i32 to f32
+
+ // CHECK-NEXT: -0
+ call @func_roundeven(%cNeg0) : (f32) -> ()
+ // CHECK-NEXT: 0
+ call @func_roundeven(%c0) : (f32) -> ()
+ // CHECK-NEXT: inf
+ call @func_roundeven(%cInf) : (f32) -> ()
+ // CHECK-NEXT: -inf
+ call @func_roundeven(%cNegInf) : (f32) -> ()
+ // CHECK-NEXT: nan
+ call @func_roundeven(%cNan) : (f32) -> ()
+ // CHECK-NEXT: -nan
+ call @func_roundeven(%cNegNan) : (f32) -> ()
+
+
+ // Values above and below 2^23 = 8388608
+ %c8388606_5 = arith.constant 8388606.5 : f32
+ %c8388607 = arith.constant 8388607.0 : f32
+ %c8388607_5 = arith.constant 8388607.5 : f32
+ %c8388608 = arith.constant 8388608.0 : f32
+ %c8388609 = arith.constant 8388609.0 : f32
+
+ // Bitcast result to int to avoid printing in scientific notation,
+ // which does not display all significant digits.
+
+ // CHECK-NEXT: 1258291196
+ // hex: 0x4AFFFFFC
+ call @func_roundeven$bitcast_result_to_int(%c8388606_5) : (f32) -> ()
+ // CHECK-NEXT: 1258291198
+ // hex: 0x4AFFFFFE
+ call @func_roundeven$bitcast_result_to_int(%c8388607) : (f32) -> ()
+ // CHECK-NEXT: 1258291200
+ // hex: 0x4B000000
+ call @func_roundeven$bitcast_result_to_int(%c8388607_5) : (f32) -> ()
+ // CHECK-NEXT: 1258291200
+ // hex: 0x4B000000
+ call @func_roundeven$bitcast_result_to_int(%c8388608) : (f32) -> ()
+ // CHECK-NEXT: 1258291201
+ // hex: 0x4B000001
+ call @func_roundeven$bitcast_result_to_int(%c8388609) : (f32) -> ()
+
+
+ // Check that vector type works
+ %cVec = arith.constant dense<[0.5]> : vector<1xf32>
+ // CHECK-NEXT: ( 0 )
+ call @func_roundeven$vector(%cVec) : (vector<1xf32>) -> ()
+
+ return
+}
+
func.func @main() {
call @exp2f() : () -> ()
call @roundf() : () -> ()
call @powf() : () -> ()
+ call @roundeven() : () -> ()
return
}
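A note on the 2^23 boundary cases exercised in the mlir-cpu-runner test above
(not part of the patch): for f32, the spacing between consecutive values is
0.5 just below 2^23 and 1.0 at and above it, so 8388607.5 is the last
representable halfway case. Decoding the expected bit pattern:

    0x4B000000 = sign 0, exponent 0x96 (150), mantissa 0
               = 1.0 x 2^(150 - 127) = 2^23 = 8388608.0

so 8388607.5, whose even neighbor is 8388608, rounds up as checked above.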