[Mlir-commits] [mlir] 87cef78 - Revert "Fix handling of special and large vals in expand pattern for `round`" and "Add pattern that expands `math.roundeven` into `math.round` + arith"
Mehdi Amini
llvmlistbot at llvm.org
Thu Apr 20 23:17:20 PDT 2023
Author: Mehdi Amini
Date: 2023-04-21T00:16:32-06:00
New Revision: 87cef78fa1c7bf6efc544e990894a6062d56abec
URL: https://github.com/llvm/llvm-project/commit/87cef78fa1c7bf6efc544e990894a6062d56abec
DIFF: https://github.com/llvm/llvm-project/commit/87cef78fa1c7bf6efc544e990894a6062d56abec.diff
LOG: Revert "Fix handling of special and large vals in expand pattern for `round`" and "Add pattern that expands `math.roundeven` into `math.round` + arith"
This reverts commit 8d2bae9abdc30e104bab00a4dd0f9d39f5bdda6e and
commit ab2fc9521ec606603412645d4a4b3cf456acd153.
Reason for the revert: tests are broken on Mac M2.
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 576ace34eac1c..245a11747d5c8 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -22,7 +22,6 @@ void populateExpandCeilFPattern(RewritePatternSet &patterns);
void populateExpandExp2FPattern(RewritePatternSet &patterns);
void populateExpandPowFPattern(RewritePatternSet &patterns);
void populateExpandRoundFPattern(RewritePatternSet &patterns);
-void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
struct MathPolynomialApproximationOptions {
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index ee8f23cf362b6..a37340d312f51 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -48,14 +48,9 @@ static Value createIntConst(Location loc, Type type, int64_t value,
static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
Type opType = operand.getType();
- Type i64Ty = b.getI64Type();
- if (auto shapedTy = dyn_cast<ShapedType>(opType))
- i64Ty = shapedTy.clone(i64Ty);
- Value fixedConvert = b.create<arith::FPToSIOp>(i64Ty, operand);
+ Value fixedConvert = b.create<arith::FPToSIOp>(b.getI64Type(), operand);
Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert);
- // The truncation does not preserve the sign when the truncated
- // value is -0. So here the sign is copied again.
- return b.create<math::CopySignOp>(fpFixedConvert, operand);
+ return fpFixedConvert;
/// Expands tanh op into
@@ -194,59 +189,23 @@ static LogicalResult convertExp2fOp(math::Exp2Op op,
static LogicalResult convertRoundOp(math::RoundOp op,
PatternRewriter &rewriter) {
- Location loc = op.getLoc();
- ImplicitLocOpBuilder b(loc, rewriter);
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operand = op.getOperand();
Type opType = operand.getType();
- Type opEType = getElementTypeOrSelf(opType);
- if (!opEType.isF32()) {
- return rewriter.notifyMatchFailure(op, "not a round of f32.");
- }
- Type i32Ty = b.getI32Type();
- if (auto shapedTy = dyn_cast<ShapedType>(opType))
- i32Ty = shapedTy.clone(i32Ty);
- Value half = createFloatConst(loc, opType, 0.5, b);
- Value c23 = createIntConst(loc, i32Ty, 23, b);
- Value c127 = createIntConst(loc, i32Ty, 127, b);
- Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
+ // Creating constants for later use.
+ Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
+ Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
+ Value negHalf = createFloatConst(op->getLoc(), opType, -0.5, rewriter);
- Value incrValue = b.create<math::CopySignOp>(half, operand);
+ Value posCheck =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, operand, zero);
+ Value incrValue =
+ b.create<arith::SelectOp>(op->getLoc(), posCheck, half, negHalf);
Value add = b.create<arith::AddFOp>(opType, operand, incrValue);
- Value fpFixedConvert = createTruncatedFPValue(add, b);
- // There are three cases where adding 0.5 to the value and truncating by
- // converting to an i64 does not result in the correct behavior:
- //
- // 1. Special values: +-inf and +-nan
- // Casting these special values to i64 has undefined behavior. To identify
- // these values, we use the fact that these values are the only float
- // values with the maximum possible biased exponent.
- //
- // 2. Large values: 2^23 <= |x| <= INT_64_MAX
- // Adding 0.5 to a float larger than or equal to 2^23 results in precision
- // errors that sometimes round the value up and sometimes round the value
- // down. For example:
- // 8388608.0 + 0.5 = 8388608.0
- // 8388609.0 + 0.5 = 8388610.0
- //
- // 3. Very large values: |x| > INT_64_MAX
- // Casting to i64 a value greater than the max i64 value will overflow the
- // i64 leading to wrong outputs.
- //
- // All three cases satisfy the property `biasedExp >= 23`.
- Value operandBitcast = b.create<arith::BitcastOp>(i32Ty, operand);
- Value operandExp = b.create<arith::AndIOp>(
- b.create<arith::ShRUIOp>(operandBitcast, c23), expMask);
- Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127);
- Value isSpecialValOrLargeVal =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, operandBiasedExp, c23);
- Value result = b.create<arith::SelectOp>(isSpecialValOrLargeVal, operand,
- fpFixedConvert);
- rewriter.replaceOp(op, result);
+ Value fpFixedConvert = createTruncatedFPValue(add, b);
+ rewriter.replaceOp(op, fpFixedConvert);
return success();
@@ -294,129 +253,6 @@ static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
return success();
-// Convert `math.roundeven` into `math.round` + arith ops
-static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
- PatternRewriter &rewriter) {
- Location loc = op.getLoc();
- ImplicitLocOpBuilder b(loc, rewriter);
- auto operand = op.getOperand();
- Type operandTy = operand.getType();
- Type resultTy = op.getType();
- Type operandETy = getElementTypeOrSelf(operandTy);
- Type resultETy = getElementTypeOrSelf(resultTy);
- if (!operandETy.isF32() || !resultETy.isF32()) {
- return rewriter.notifyMatchFailure(op, "not a roundeven of f32.");
- }
- Type i32Ty = b.getI32Type();
- Type f32Ty = b.getF32Type();
- if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
- i32Ty = shapedTy.clone(i32Ty);
- f32Ty = shapedTy.clone(f32Ty);
- }
- Value c1Float = createFloatConst(loc, f32Ty, 1.0, b);
- Value c0 = createIntConst(loc, i32Ty, 0, b);
- Value c1 = createIntConst(loc, i32Ty, 1, b);
- Value cNeg1 = createIntConst(loc, i32Ty, -1, b);
- Value c23 = createIntConst(loc, i32Ty, 23, b);
- Value c31 = createIntConst(loc, i32Ty, 31, b);
- Value c127 = createIntConst(loc, i32Ty, 127, b);
- Value c2To22 = createIntConst(loc, i32Ty, 1 << 22, b);
- Value c23Mask = createIntConst(loc, i32Ty, (1 << 23) - 1, b);
- Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
- Value operandBitcast = b.create<arith::BitcastOp>(i32Ty, operand);
- Value round = b.create<math::RoundOp>(operand);
- Value roundBitcast = b.create<arith::BitcastOp>(i32Ty, round);
- // Get biased exponents for operand and round(operand)
- Value operandExp = b.create<arith::AndIOp>(
- b.create<arith::ShRUIOp>(operandBitcast, c23), expMask);
- Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127);
- Value roundExp = b.create<arith::AndIOp>(
- b.create<arith::ShRUIOp>(roundBitcast, c23), expMask);
- Value roundBiasedExp = b.create<arith::SubIOp>(roundExp, c127);
- auto safeShiftRight = [&](Value x, Value shift) -> Value {
- // Clamp shift to valid range [0, 31] to avoid undefined behavior
- Value clampedShift = b.create<arith::MaxSIOp>(shift, c0);
- clampedShift = b.create<arith::MinSIOp>(clampedShift, c31);
- return b.create<arith::ShRUIOp>(x, clampedShift);
- };
- auto maskMantissa = [&](Value mantissa,
- Value mantissaMaskRightShift) -> Value {
- Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
- return b.create<arith::AndIOp>(mantissa, shiftedMantissaMask);
- };
- // A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring
- // the leftmost `clamp(biasedExp - 1, 0, 23)` bits, is zero. Large numbers
- // with `biasedExp > 23` (numbers where there is not enough precision to store
- // decimals) are always even, and they satisfy the even condition trivially
- // since the mantissa without all its bits is zero. The even condition
- // is also true for +-0, since they have `biasedExp = -127` and the entire
- // mantissa is zero. The case of +-1 has to be handled separately. Here
- // we identify these values by noting that +-1 are the only whole numbers with
- // `biasedExp == 0`.
- //
- // The special values +-inf and +-nan also satisfy the same property that
- // whole non-unit even numbers satisfy. In particular, the special values have
- // `biasedExp > 23`, so they get treated as large numbers with no room for
- // decimals, which are always even.
- Value roundBiasedExpEq0 =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, roundBiasedExp, c0);
- Value roundBiasedExpMinus1 = b.create<arith::SubIOp>(roundBiasedExp, c1);
- Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
- Value roundIsNotEvenOrSpecialVal = b.create<arith::CmpIOp>(
- arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
- roundIsNotEvenOrSpecialVal =
- b.create<arith::OrIOp>(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
- // A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive
- // integers if the bit at index `biasedExp` starting from the left in the
- // mantissa is 1 and all the bits to the right are zero. Values with
- // `biasedExp >= 23` don't have decimals, so they are never halfway. The
- // values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`,
- // so these are handled separately. In particular, if `biasedExp == -1`, the
- // value is halfway if the entire mantissa is zero.
- Value operandBiasedExpEqNeg1 = b.create<arith::CmpIOp>(
- arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
- Value expectedOperandMaskedMantissa = b.create<arith::SelectOp>(
- operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
- Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
- Value operandIsHalfway =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, operandMaskedMantissa,
- expectedOperandMaskedMantissa);
- // Ensure `biasedExp` is in the valid range for half values.
- Value operandBiasedExpGeNeg1 = b.create<arith::CmpIOp>(
- arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
- Value operandBiasedExpLt23 =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, operandBiasedExp, c23);
- operandIsHalfway =
- b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpLt23);
- operandIsHalfway =
- b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpGeNeg1);
- // Adjust rounded operand with `round(operand) - sign(operand)` to correct the
- // case where `round` rounded in the opposite direction of `roundeven`.
- Value sign = b.create<math::CopySignOp>(c1Float, operand);
- Value roundShifted = b.create<arith::SubFOp>(round, sign);
- // If the rounded value is even or a special value, we default to the behavior
- // of `math.round`.
- Value needsShift =
- b.create<arith::AndIOp>(roundIsNotEvenOrSpecialVal, operandIsHalfway);
- Value result = b.create<arith::SelectOp>(needsShift, roundShifted, round);
- // The `x - sign` adjustment does not preserve the sign when we are adjusting
- // the value -1 to -0. So here the sign is copied again to ensure that -0.5 is
- // rounded to -0.0.
- result = b.create<math::CopySignOp>(result, operand);
- rewriter.replaceOp(op, result);
- return success();
void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
@@ -452,7 +288,3 @@ void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
-void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) {
- patterns.add(convertRoundEvenOp);
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index c28e2141db061..382278c060c8e 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -141,10 +141,9 @@ func.func @floorf_func(%a: f64) -> f64 {
// CHECK-DAG: [[CST_0:%.+]] = arith.constant -1.000
// CHECK-NEXT: [[CVTI:%.+]] = arith.fptosi [[ARG0]]
// CHECK-NEXT: [[CVTF:%.+]] = arith.sitofp [[CVTI]]
- // CHECK-NEXT: [[COPYSIGN:%.+]] = math.copysign [[CVTF]], [[ARG0]]
// CHECK-NEXT: [[COMP:%.+]] = arith.cmpf olt, [[ARG0]], [[CST]]
// CHECK-NEXT: [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
- // CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]]
+ // CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[CVTF]], [[INCR]]
// CHECK-NEXT: return [[ADDF]]
%ret = math.floor %a : f64
return %ret : f64
@@ -159,10 +158,9 @@ func.func @ceilf_func(%a: f64) -> f64 {
// CHECK-DAG: [[CST_0:%.+]] = arith.constant 1.000
// CHECK-NEXT: [[CVTI:%.+]] = arith.fptosi [[ARG0]]
// CHECK-NEXT: [[CVTF:%.+]] = arith.sitofp [[CVTI]]
- // CHECK-NEXT: [[COPYSIGN:%.+]] = math.copysign [[CVTF]], [[ARG0]]
- // CHECK-NEXT: [[COMP:%.+]] = arith.cmpf ogt, [[ARG0]], [[COPYSIGN]]
+ // CHECK-NEXT: [[COMP:%.+]] = arith.cmpf ogt, [[ARG0]], [[CVTF]]
// CHECK-NEXT: [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
- // CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]]
+ // CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[CVTF]], [[INCR]]
// CHECK-NEXT: return [[ADDF]]
%ret = math.ceil %a : f64
return %ret : f64
@@ -195,26 +193,19 @@ func.func @exp2f_func_tensor(%a: tensor<1xf32>) -> tensor<1xf32> {
// -----
// CHECK-LABEL: func @roundf_func
-// CHECK-SAME: (%[[ARG0:.*]]: f32) -> f32
-func.func @roundf_func(%a: f32) -> f32 {
- // CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01
- // CHECK-DAG: %[[C23:.*]] = arith.constant 23
- // CHECK-DAG: %[[C127:.*]] = arith.constant 127
- // CHECK-DAG: %[[EXP_MASK:.*]] = arith.constant 255
- // CHECK-DAG: %[[SHIFT:.*]] = math.copysign %[[HALF]], %[[ARG0]]
- // CHECK-DAG: %[[ARG_SHIFTED:.*]] = arith.addf %[[ARG0]], %[[SHIFT]]
- // CHECK-DAG: %[[FIXED_CONVERT:.*]] = arith.fptosi %[[ARG_SHIFTED]]
- // CHECK-DAG: %[[FP_FIXED_CONVERT_0:.*]] = arith.sitofp %[[FIXED_CONVERT]]
- // CHECK-DAG: %[[FP_FIXED_CONVERT_1:.*]] = math.copysign %[[FP_FIXED_CONVERT_0]], %[[ARG_SHIFTED]]
- // CHECK-DAG: %[[ARG_BITCAST:.*]] = arith.bitcast %[[ARG0]] : f32 to i32
- // CHECK-DAG: %[[ARG_BITCAST_SHIFTED:.*]] = arith.shrui %[[ARG_BITCAST]], %[[C23]]
- // CHECK-DAG: %[[ARG_EXP:.*]] = arith.andi %[[ARG_BITCAST_SHIFTED]], %[[EXP_MASK]]
- // CHECK-DAG: %[[ARG_BIASED_EXP:.*]] = arith.subi %[[ARG_EXP]], %[[C127]]
- // CHECK-DAG: %[[IS_SPECIAL_VAL:.*]] = arith.cmpi sge, %[[ARG_BIASED_EXP]], %[[C23]]
- // CHECK-DAG: %[[RESULT:.*]] = arith.select %[[IS_SPECIAL_VAL]], %[[ARG0]], %[[FP_FIXED_CONVERT_1]]
- // CHECK: return %[[RESULT]]
- %ret = math.round %a : f32
- return %ret : f32
+// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64
+func.func @roundf_func(%a: f64) -> f64 {
+ // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000
+ // CHECK-DAG: [[CST_0:%.+]] = arith.constant 5.000000e-01
+ // CHECK-DAG: [[CST_1:%.+]] = arith.constant -5.000000e-01
+ // CHECK-DAG: [[COMP:%.+]] = arith.cmpf oge, [[ARG0]], [[CST]]
+ // CHECK-DAG: [[SEL:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST_1]]
+ // CHECK-DAG: [[ADDF:%.+]] = arith.addf [[ARG0]], [[SEL]]
+ // CHECK-DAG: [[CVTI:%.+]] = arith.fptosi [[ADDF]]
+ // CHECK-DAG: [[CVTF:%.+]] = arith.sitofp [[CVTI]]
+ // CHECK: return [[CVTF]]
+ %ret = math.round %a : f64
+ return %ret : f64
// -----
@@ -229,105 +220,3 @@ func.func @powf_func(%a: f64, %b: f64) ->f64 {
%ret = math.powf %a, %b : f64
return %ret : f64
-// -----
-// CHECK-LABEL: func.func @roundeven
-func.func @roundeven(%arg: f32) -> f32 {
- %res = math.roundeven %arg : f32
- return %res : f32
-// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
-// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : i32
-// CHECK-DAG: %[[C_NEG_1:.*]] = arith.constant -1 : i32
-// CHECK-DAG: %[[C_1_FLOAT:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-DAG: %[[C_23:.*]] = arith.constant 23 : i32
-// CHECK-DAG: %[[C_31:.*]] = arith.constant 31 : i32
-// CHECK-DAG: %[[C_127:.*]] = arith.constant 127 : i32
-// CHECK-DAG: %[[C_4194304:.*]] = arith.constant 4194304 : i32
-// CHECK-DAG: %[[C_8388607:.*]] = arith.constant 8388607 : i32
-// CHECK-DAG: %[[EXP_MASK:.*]] = arith.constant 255 : i32
-// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01
-// CHECK: %[[OPERAND_BITCAST:.*]] = arith.bitcast %[[VAL_0]] : f32 to i32
-// Calculate `math.round(operand)` using expansion pattern for `round` and
-// bitcast result to i32
-// CHECK: %[[SHIFT:.*]] = math.copysign %[[HALF]], %[[VAL_0]]
-// CHECK: %[[ARG_SHIFTED:.*]] = arith.addf %[[VAL_0]], %[[SHIFT]]
-// CHECK: %[[FIXED_CONVERT:.*]] = arith.fptosi %[[ARG_SHIFTED]]
-// CHECK: %[[FP_FIXED_CONVERT_0:.*]] = arith.sitofp %[[FIXED_CONVERT]]
-// CHECK: %[[FP_FIXED_CONVERT_1:.*]] = math.copysign %[[FP_FIXED_CONVERT_0]], %[[ARG_SHIFTED]]
-// CHECK: %[[ARG_BITCAST:.*]] = arith.bitcast %[[VAL_0]] : f32 to i32
-// CHECK: %[[ARG_BITCAST_SHIFTED:.*]] = arith.shrui %[[ARG_BITCAST]], %[[C_23]]
-// CHECK: %[[ARG_EXP:.*]] = arith.andi %[[ARG_BITCAST_SHIFTED]], %[[EXP_MASK]]
-// CHECK: %[[ARG_BIASED_EXP:.*]] = arith.subi %[[ARG_EXP]], %[[C_127]]
-// CHECK: %[[IS_SPECIAL_VAL:.*]] = arith.cmpi sge, %[[ARG_BIASED_EXP]], %[[C_23]]
-// CHECK: %[[ROUND:.*]] = arith.select %[[IS_SPECIAL_VAL]], %[[VAL_0]], %[[FP_FIXED_CONVERT_1]]
-// CHECK: %[[ROUND_BITCAST:.*]] = arith.bitcast %[[ROUND]] : f32 to i32
-// Get biased exponents of `round` and `operand`
-// CHECK: %[[SHIFTED_OPERAND_BITCAST:.*]] = arith.shrui %[[OPERAND_BITCAST]], %[[C_23]] : i32
-// CHECK: %[[OPERAND_EXP:.*]] = arith.andi %[[SHIFTED_OPERAND_BITCAST]], %[[EXP_MASK]] : i32
-// CHECK: %[[OPERAND_BIASED_EXP:.*]] = arith.subi %[[OPERAND_EXP]], %[[C_127]] : i32
-// CHECK: %[[SHIFTED_ROUND_BITCAST:.*]] = arith.shrui %[[ROUND_BITCAST]], %[[C_23]] : i32
-// CHECK: %[[ROUND_EXP:.*]] = arith.andi %[[SHIFTED_ROUND_BITCAST]], %[[EXP_MASK]] : i32
-// CHECK: %[[ROUND_BIASED_EXP:.*]] = arith.subi %[[ROUND_EXP]], %[[C_127]] : i32
-// Determine if `ROUND_BITCAST` is an even whole number or a special value
-// +-inf, +-nan.
-// Mask mantissa of `ROUND_BITCAST` with a mask shifted to the right by
-// CHECK-DAG: %[[ROUND_BIASED_EXP_MINUS_1:.*]] = arith.subi %[[ROUND_BIASED_EXP]], %[[C_1]] : i32
-// CHECK-DAG: %[[CLAMPED_SHIFT_0:.*]] = arith.maxsi %[[ROUND_BIASED_EXP_MINUS_1]], %[[C_0]] : i32
-// CHECK-DAG: %[[CLAMPED_SHIFT_1:.*]] = arith.minsi %[[CLAMPED_SHIFT_0]], %[[C_31]] : i32
-// CHECK-DAG: %[[SHIFTED_MANTISSA_MASK_0:.*]] = arith.shrui %[[C_8388607]], %[[CLAMPED_SHIFT_1]] : i32
-// `ROUND_BITCAST` is not even whole number or special value if masked
-// mantissa is != 0 or `ROUND_BIASED_EXP == 0`
-// CHECK-DAG: %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_0:.*]] = arith.cmpi ne, %[[ROUND_MASKED_MANTISSA]], %[[C_0]] : i32
-// CHECK-DAG: %[[ROUND_BIASED_EXP_EQ_0:.*]] = arith.cmpi eq, %[[ROUND_BIASED_EXP]], %[[C_0]] : i32
-// Determine if operand is halfway between two integer values
-// CHECK: %[[OPERAND_BIASED_EXP_EQ_NEG_1:.*]] = arith.cmpi eq, %[[OPERAND_BIASED_EXP]], %[[C_NEG_1]] : i32
-// CHECK: %[[CLAMPED_SHIFT_2:.*]] = arith.maxsi %[[OPERAND_BIASED_EXP]], %[[C_0]] : i32
-// CHECK: %[[CLAMPED_SHIFT_3:.*]] = arith.minsi %[[CLAMPED_SHIFT_2]], %[[C_31]] : i32
-// CHECK: %[[SHIFTED_2_TO_22:.*]] = arith.shrui %[[C_4194304]], %[[CLAMPED_SHIFT_3]] : i32
-// A value with `0 <= BIASED_EXP < 23` is halfway between two consecutive
-// integers if the bit at index `BIASED_EXP` starting from the left in the
-// mantissa is 1 and all the bits to the right are zero. For the case where
-// `BIASED_EXP == -1, the expected mantissa is all zeros.
-// CHECK: %[[EXPECTED_OPERAND_MASKED_MANTISSA:.*]] = arith.select %[[OPERAND_BIASED_EXP_EQ_NEG_1]], %[[C_0]], %[[SHIFTED_2_TO_22]] : i32
-// Mask mantissa of `OPERAND_BITCAST` with a mask shifted to the right by
-// CHECK: %[[CLAMPED_SHIFT_4:.*]] = arith.maxsi %[[OPERAND_BIASED_EXP]], %[[C_0]] : i32
-// CHECK: %[[CLAMPED_SHIFT_5:.*]] = arith.minsi %[[CLAMPED_SHIFT_4]], %[[C_31]] : i32
-// CHECK: %[[SHIFTED_MANTISSA_MASK_1:.*]] = arith.shrui %[[C_8388607]], %[[CLAMPED_SHIFT_5]] : i32
-// The operand is halfway between two integers if the masked mantissa is equal
-// to the expected mantissa and the biased exponent is in the range
-// [-1, 23).
-// CHECK-DAG: %[[OPERAND_BIASED_EXP_GE_NEG_1:.*]] = arith.cmpi sge, %[[OPERAND_BIASED_EXP]], %[[C_NEG_1]] : i32
-// CHECK-DAG: %[[OPERAND_BIASED_EXP_LT_23:.*]] = arith.cmpi slt, %[[OPERAND_BIASED_EXP]], %[[C_23]] : i32
-// CHECK-DAG: %[[OPERAND_IS_HALFWAY_1:.*]] = arith.andi %[[OPERAND_IS_HALFWAY_0]], %[[OPERAND_BIASED_EXP_LT_23]] : i1
-// Adjust rounded operand with `round(operand) - sign(operand)` to correct the
-// case where `round` rounded in the oppositve direction of `roundeven`.
-// CHECK: %[[SIGN:.*]] = math.copysign %[[C_1_FLOAT]], %[[VAL_0]] : f32
-// CHECK: %[[ROUND_SHIFTED:.*]] = arith.subf %[[ROUND]], %[[SIGN]] : f32
-// CHECK: %[[NEEDS_SHIFT:.*]] = arith.andi %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_1]], %[[OPERAND_IS_HALFWAY_2]] : i1
-// CHECK: %[[RESULT:.*]] = arith.select %[[NEEDS_SHIFT]], %[[ROUND_SHIFTED]], %[[ROUND]] : f32
-// The `x - sign` adjustment does not preserve the sign when we are adjusting the value -1 to -0.
-// CHECK: %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f32
-// CHECK: return %[[COPYSIGN]] : f32
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index 6dae8213dd41e..c9b3357c9b508 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -45,7 +45,6 @@ void TestExpandMathPass::runOnOperation() {
- populateExpandRoundEvenPattern(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index 2a27d0f6d37c3..16a239ea735c5 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -19,37 +19,37 @@ func.func @exp2f() {
%a = arith.constant 1.0 : f64
call @func_exp2f(%a) : (f64) -> ()
- // CHECK-NEXT: 4
+ // CHECK: 4
%b = arith.constant 2.0 : f64
call @func_exp2f(%b) : (f64) -> ()
- // CHECK-NEXT: 5.65685
+ // CHECK: 5.65685
%c = arith.constant 2.5 : f64
call @func_exp2f(%c) : (f64) -> ()
- // CHECK-NEXT: 0.29730
+ // CHECK: 0.29730
%d = arith.constant -1.75 : f64
call @func_exp2f(%d) : (f64) -> ()
- // CHECK-NEXT: 1.09581
+ // CHECK: 1.09581
%e = arith.constant 0.132 : f64
call @func_exp2f(%e) : (f64) -> ()
- // CHECK-NEXT: inf
+ // CHECK: inf
%f1 = arith.constant 0.00 : f64
%f2 = arith.constant 1.00 : f64
%f = arith.divf %f2, %f1 : f64
call @func_exp2f(%f) : (f64) -> ()
- // CHECK-NEXT: inf
+ // CHECK: inf
%g = arith.constant 5038939.0 : f64
call @func_exp2f(%g) : (f64) -> ()
- // CHECK-NEXT: 0
+ // CHECK: 0
%neg_inf = arith.constant 0xff80000000000000 : f64
call @func_exp2f(%neg_inf) : (f64) -> ()
- // CHECK-NEXT: inf
+ // CHECK: inf
%i = arith.constant 0x7fc0000000000000 : f64
call @func_exp2f(%i) : (f64) -> ()
@@ -64,113 +64,39 @@ func.func @func_roundf(%a : f32) {
-func.func @func_roundf$bitcast_result_to_int(%a : f32) {
- %b = math.round %a : f32
- %c = arith.bitcast %b : f32 to i32
- vector.print %c : i32
- return
-func.func @func_roundf$vector(%a : vector<1xf32>) {
- %b = math.round %a : vector<1xf32>
- vector.print %b : vector<1xf32>
- return
func.func @roundf() {
- // CHECK-NEXT: 4
+ // CHECK: 4
%a = arith.constant 3.8 : f32
call @func_roundf(%a) : (f32) -> ()
- // CHECK-NEXT: -4
+ // CHECK: -4
%b = arith.constant -3.8 : f32
call @func_roundf(%b) : (f32) -> ()
- // CHECK-NEXT: -4
- %c = arith.constant -4.2 : f32
+ // CHECK: 0
+ %c = arith.constant 0.0 : f32
call @func_roundf(%c) : (f32) -> ()
- // CHECK-NEXT: -495
- %d = arith.constant -495.0 : f32
+ // CHECK: -4
+ %d = arith.constant -4.2 : f32
call @func_roundf(%d) : (f32) -> ()
- // CHECK-NEXT: 495
- %e = arith.constant 495.0 : f32
+ // CHECK: -495
+ %e = arith.constant -495.0 : f32
call @func_roundf(%e) : (f32) -> ()
- // CHECK-NEXT: 9
- %f = arith.constant 8.5 : f32
+ // CHECK: 495
+ %f = arith.constant 495.0 : f32
call @func_roundf(%f) : (f32) -> ()
- // CHECK-NEXT: -9
- %g = arith.constant -8.5 : f32
+ // CHECK: 9
+ %g = arith.constant 8.5 : f32
call @func_roundf(%g) : (f32) -> ()
- // CHECK-NEXT: -0
- %h = arith.constant -0.4 : f32
+ // CHECK: -9
+ %h = arith.constant -8.5 : f32
call @func_roundf(%h) : (f32) -> ()
- // Special values: 0, -0, inf, -inf, nan, -nan
- %cNeg0 = arith.constant -0.0 : f32
- %c0 = arith.constant 0.0 : f32
- %cInfInt = arith.constant 0x7f800000 : i32
- %cInf = arith.bitcast %cInfInt : i32 to f32
- %cNegInfInt = arith.constant 0xff800000 : i32
- %cNegInf = arith.bitcast %cNegInfInt : i32 to f32
- %cNanInt = arith.constant 0x7fc00000 : i32
- %cNan = arith.bitcast %cNanInt : i32 to f32
- %cNegNanInt = arith.constant 0xffc00000 : i32
- %cNegNan = arith.bitcast %cNegNanInt : i32 to f32
- // CHECK-NEXT: -0
- call @func_roundf(%cNeg0) : (f32) -> ()
- // CHECK-NEXT: 0
- call @func_roundf(%c0) : (f32) -> ()
- // CHECK-NEXT: inf
- call @func_roundf(%cInf) : (f32) -> ()
- // CHECK-NEXT: -inf
- call @func_roundf(%cNegInf) : (f32) -> ()
- // CHECK-NEXT: nan
- call @func_roundf(%cNan) : (f32) -> ()
- // CHECK-NEXT: -nan
- call @func_roundf(%cNegNan) : (f32) -> ()
- // Very large values (greater than INT_64_MAX)
- %c2To100 = arith.constant 1.268e30 : f32 // 2^100
- // CHECK-NEXT: 1.268e+30
- call @func_roundf(%c2To100) : (f32) -> ()
- // Values above and below 2^23 = 8388608
- %c8388606_5 = arith.constant 8388606.5 : f32
- %c8388607 = arith.constant 8388607.0 : f32
- %c8388607_5 = arith.constant 8388607.5 : f32
- %c8388608 = arith.constant 8388608.0 : f32
- %c8388609 = arith.constant 8388609.0 : f32
- // Bitcast result to int to avoid printing in scientific notation,
- // which does not display all significant digits.
- // CHECK-NEXT: 1258291198
- // hex: 0x4AFFFFFE
- call @func_roundf$bitcast_result_to_int(%c8388606_5) : (f32) -> ()
- // CHECK-NEXT: 1258291198
- // hex: 0x4AFFFFFE
- call @func_roundf$bitcast_result_to_int(%c8388607) : (f32) -> ()
- // CHECK-NEXT: 1258291200
- // hex: 0x4B000000
- call @func_roundf$bitcast_result_to_int(%c8388607_5) : (f32) -> ()
- // CHECK-NEXT: 1258291200
- // hex: 0x4B000000
- call @func_roundf$bitcast_result_to_int(%c8388608) : (f32) -> ()
- // CHECK-NEXT: 1258291201
- // hex: 0x4B000001
- call @func_roundf$bitcast_result_to_int(%c8388609) : (f32) -> ()
- // Check that vector type works
- %cVec = arith.constant dense<[0.5]> : vector<1xf32>
- // CHECK-NEXT: ( 1 )
- call @func_roundf$vector(%cVec) : (vector<1xf32>) -> ()
@@ -184,232 +110,52 @@ func.func @func_powff64(%a : f64, %b : f64) {
func.func @powf() {
- // CHECK-NEXT: 16
+ // CHECK: 16
%a = arith.constant 4.0 : f64
%a_p = arith.constant 2.0 : f64
call @func_powff64(%a, %a_p) : (f64, f64) -> ()
- // CHECK-NEXT: nan
+ // CHECK: nan
%b = arith.constant -3.0 : f64
%b_p = arith.constant 3.0 : f64
call @func_powff64(%b, %b_p) : (f64, f64) -> ()
- // CHECK-NEXT: 2.343
+ // CHECK: 2.343
%c = arith.constant 2.343 : f64
%c_p = arith.constant 1.000 : f64
call @func_powff64(%c, %c_p) : (f64, f64) -> ()
- // CHECK-NEXT: 0.176171
+ // CHECK: 0.176171
%d = arith.constant 4.25 : f64
%d_p = arith.constant -1.2 : f64
call @func_powff64(%d, %d_p) : (f64, f64) -> ()
- // CHECK-NEXT: 1
+ // CHECK: 1
%e = arith.constant 4.385 : f64
%e_p = arith.constant 0.00 : f64
call @func_powff64(%e, %e_p) : (f64, f64) -> ()
- // CHECK-NEXT: 6.62637
+ // CHECK: 6.62637
%f = arith.constant 4.835 : f64
%f_p = arith.constant 1.2 : f64
call @func_powff64(%f, %f_p) : (f64, f64) -> ()
- // CHECK-NEXT: nan
+ // CHECK: nan
%g = arith.constant 0xff80000000000000 : f64
call @func_powff64(%g, %g) : (f64, f64) -> ()
- // CHECK-NEXT: nan
+ // CHECK: nan
%h = arith.constant 0x7fffffffffffffff : f64
call @func_powff64(%h, %h) : (f64, f64) -> ()
- // CHECK-NEXT: nan
+ // CHECK: nan
%i = arith.constant 1.0 : f64
call @func_powff64(%i, %h) : (f64, f64) -> ()
- // CHECK-NEXT: inf
+ // CHECK: inf
%j = arith.constant 29385.0 : f64
%j_p = arith.constant 23598.0 : f64
- call @func_powff64(%j, %j_p) : (f64, f64) -> ()
- return
-// -------------------------------------------------------------------------- //
-// roundeven.
-// -------------------------------------------------------------------------- //
-func.func @func_roundeven(%a : f32) {
- %b = math.roundeven %a : f32
- vector.print %b : f32
- return
-func.func @func_roundeven$bitcast_result_to_int(%a : f32) {
- %b = math.roundeven %a : f32
- %c = arith.bitcast %b : f32 to i32
- vector.print %c : i32
- return
-func.func @func_roundeven$vector(%a : vector<1xf32>) {
- %b = math.roundeven %a : vector<1xf32>
- vector.print %b : vector<1xf32>
- return
-func.func @roundeven() {
- %c0_25 = arith.constant 0.25 : f32
- %c0_5 = arith.constant 0.5 : f32
- %c0_75 = arith.constant 0.75 : f32
- %c1 = arith.constant 1.0 : f32
- %c1_25 = arith.constant 1.25 : f32
- %c1_5 = arith.constant 1.5 : f32
- %c1_75 = arith.constant 1.75 : f32
- %c2 = arith.constant 2.0 : f32
- %c2_25 = arith.constant 2.25 : f32
- %c2_5 = arith.constant 2.5 : f32
- %c2_75 = arith.constant 2.75 : f32
- %c3 = arith.constant 3.0 : f32
- %c3_25 = arith.constant 3.25 : f32
- %c3_5 = arith.constant 3.5 : f32
- %c3_75 = arith.constant 3.75 : f32
- %cNeg0_25 = arith.constant -0.25 : f32
- %cNeg0_5 = arith.constant -0.5 : f32
- %cNeg0_75 = arith.constant -0.75 : f32
- %cNeg1 = arith.constant -1.0 : f32
- %cNeg1_25 = arith.constant -1.25 : f32
- %cNeg1_5 = arith.constant -1.5 : f32
- %cNeg1_75 = arith.constant -1.75 : f32
- %cNeg2 = arith.constant -2.0 : f32
- %cNeg2_25 = arith.constant -2.25 : f32
- %cNeg2_5 = arith.constant -2.5 : f32
- %cNeg2_75 = arith.constant -2.75 : f32
- %cNeg3 = arith.constant -3.0 : f32
- %cNeg3_25 = arith.constant -3.25 : f32
- %cNeg3_5 = arith.constant -3.5 : f32
- %cNeg3_75 = arith.constant -3.75 : f32
- // CHECK-NEXT: 0
- call @func_roundeven(%c0_25) : (f32) -> ()
- // CHECK-NEXT: 0
- call @func_roundeven(%c0_5) : (f32) -> ()
- // CHECK-NEXT: 1
- call @func_roundeven(%c0_75) : (f32) -> ()
- // CHECK-NEXT: 1
- call @func_roundeven(%c1) : (f32) -> ()
- // CHECK-NEXT: 1
- call @func_roundeven(%c1_25) : (f32) -> ()
- // CHECK-NEXT: 2
- call @func_roundeven(%c1_5) : (f32) -> ()
- // CHECK-NEXT: 2
- call @func_roundeven(%c1_75) : (f32) -> ()
- // CHECK-NEXT: 2
- call @func_roundeven(%c2) : (f32) -> ()
- // CHECK-NEXT: 2
- call @func_roundeven(%c2_25) : (f32) -> ()
- // CHECK-NEXT: 2
- call @func_roundeven(%c2_5) : (f32) -> ()
- // CHECK-NEXT: 3
- call @func_roundeven(%c2_75) : (f32) -> ()
- // CHECK-NEXT: 3
- call @func_roundeven(%c3) : (f32) -> ()
- // CHECK-NEXT: 3
- call @func_roundeven(%c3_25) : (f32) -> ()
- // CHECK-NEXT: 4
- call @func_roundeven(%c3_5) : (f32) -> ()
- // CHECK-NEXT: 4
- call @func_roundeven(%c3_75) : (f32) -> ()
- // CHECK-NEXT: -0
- call @func_roundeven(%cNeg0_25) : (f32) -> ()
- // CHECK-NEXT: -0
- call @func_roundeven(%cNeg0_5) : (f32) -> ()
- // CHECK-NEXT: -1
- call @func_roundeven(%cNeg0_75) : (f32) -> ()
- // CHECK-NEXT: -1
- call @func_roundeven(%cNeg1) : (f32) -> ()
- // CHECK-NEXT: -1
- call @func_roundeven(%cNeg1_25) : (f32) -> ()
- // CHECK-NEXT: -2
- call @func_roundeven(%cNeg1_5) : (f32) -> ()
- // CHECK-NEXT: -2
- call @func_roundeven(%cNeg1_75) : (f32) -> ()
- // CHECK-NEXT: -2
- call @func_roundeven(%cNeg2) : (f32) -> ()
- // CHECK-NEXT: -2
- call @func_roundeven(%cNeg2_25) : (f32) -> ()
- // CHECK-NEXT: -2
- call @func_roundeven(%cNeg2_5) : (f32) -> ()
- // CHECK-NEXT: -3
- call @func_roundeven(%cNeg2_75) : (f32) -> ()
- // CHECK-NEXT: -3
- call @func_roundeven(%cNeg3) : (f32) -> ()
- // CHECK-NEXT: -3
- call @func_roundeven(%cNeg3_25) : (f32) -> ()
- // CHECK-NEXT: -4
- call @func_roundeven(%cNeg3_5) : (f32) -> ()
- // CHECK-NEXT: -4
- call @func_roundeven(%cNeg3_75) : (f32) -> ()
- // Special values: 0, -0, inf, -inf, nan, -nan
- %cNeg0 = arith.constant -0.0 : f32
- %c0 = arith.constant 0.0 : f32
- %cInfInt = arith.constant 0x7f800000 : i32
- %cInf = arith.bitcast %cInfInt : i32 to f32
- %cNegInfInt = arith.constant 0xff800000 : i32
- %cNegInf = arith.bitcast %cNegInfInt : i32 to f32
- %cNanInt = arith.constant 0x7fc00000 : i32
- %cNan = arith.bitcast %cNanInt : i32 to f32
- %cNegNanInt = arith.constant 0xffc00000 : i32
- %cNegNan = arith.bitcast %cNegNanInt : i32 to f32
- // CHECK-NEXT: -0
- call @func_roundeven(%cNeg0) : (f32) -> ()
- // CHECK-NEXT: 0
- call @func_roundeven(%c0) : (f32) -> ()
- // CHECK-NEXT: inf
- call @func_roundeven(%cInf) : (f32) -> ()
- // CHECK-NEXT: -inf
- call @func_roundeven(%cNegInf) : (f32) -> ()
- // CHECK-NEXT: nan
- call @func_roundeven(%cNan) : (f32) -> ()
- // CHECK-NEXT: -nan
- call @func_roundeven(%cNegNan) : (f32) -> ()
- // Values above and below 2^23 = 8388608
- %c8388606_5 = arith.constant 8388606.5 : f32
- %c8388607 = arith.constant 8388607.0 : f32
- %c8388607_5 = arith.constant 8388607.5 : f32
- %c8388608 = arith.constant 8388608.0 : f32
- %c8388609 = arith.constant 8388609.0 : f32
- // Bitcast result to int to avoid printing in scientific notation,
- // which does not display all significant digits.
- // CHECK-NEXT: 1258291196
- // hex: 0x4AFFFFFC
- call @func_roundeven$bitcast_result_to_int(%c8388606_5) : (f32) -> ()
- // CHECK-NEXT: 1258291198
- // hex: 0x4AFFFFFE
- call @func_roundeven$bitcast_result_to_int(%c8388607) : (f32) -> ()
- // CHECK-NEXT: 1258291200
- // hex: 0x4B000000
- call @func_roundeven$bitcast_result_to_int(%c8388607_5) : (f32) -> ()
- // CHECK-NEXT: 1258291200
- // hex: 0x4B000000
- call @func_roundeven$bitcast_result_to_int(%c8388608) : (f32) -> ()
- // CHECK-NEXT: 1258291201
- // hex: 0x4B000001
- call @func_roundeven$bitcast_result_to_int(%c8388609) : (f32) -> ()
- // Check that vector type works
- %cVec = arith.constant dense<[0.5]> : vector<1xf32>
- // CHECK-NEXT: ( 0 )
- call @func_roundeven$vector(%cVec) : (vector<1xf32>) -> ()
+ call @func_powff64(%j, %j_p) : (f64, f64) -> ()
@@ -417,6 +163,5 @@ func.func @main() {
call @exp2f() : () -> ()
call @roundf() : () -> ()
call @powf() : () -> ()
- call @roundeven() : () -> ()
More information about the Mlir-commits mailing list