[Mlir-commits] [mlir] [mlir][math] Update `convertPowfOp` in `ExpandPatterns.cpp` (PR #124402)
Hyunsung Lee
llvmlistbot at llvm.org
Tue Jan 28 18:02:12 PST 2025
https://github.com/ita9naiwa updated https://github.com/llvm/llvm-project/pull/124402
>From 19ba511db07db884ac124b3da4551140cd29ecbf Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sat, 25 Jan 2025 20:01:11 +0900
Subject: [PATCH 1/2] Update ExpandPatterns.cpp
---
.../Math/Transforms/ExpandPatterns.cpp | 34 +++----
mlir/test/Dialect/Math/expand-math.mlir | 88 +++++++++----------
2 files changed, 59 insertions(+), 63 deletions(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 3dadf9474cf4f6..df74bc3982298c 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -311,7 +311,8 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
return success();
}
-// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
+// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(|a|))
+// * sign(a)^b
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operandA = op.getOperand(0);
@@ -319,32 +320,31 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
Type opType = operandA.getType();
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
- Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
- Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
- Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);
+ Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
- Value logA = b.create<math::LogOp>(opType, opASquared);
- Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
+ Value absA = b.create<math::AbsFOp>(opType, operandA);
+ Value isNegative =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
+ Value signA =
+ b.create<arith::SelectOp>(op->getLoc(), isNegative, negOne, one);
+ Value logA = b.create<math::LogOp>(opType, absA);
+ Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
Value expResult = b.create<math::ExpOp>(opType, mult);
- Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
- Value negCheck =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
- Value oddPower =
+ Value isOdd =
b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
- Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
+ Value signedExpResult = b.create<arith::SelectOp>(
+ op->getLoc(), isOdd, b.create<arith::MulFOp>(opType, expResult, signA),
+ expResult);
- // First, we select between the exp value and the adjusted value for odd
- // powers of negatives. Then, we ensure that one is produced if `b` is zero.
// This corresponds to `libm` behavior, even for `0^0`. Without this check,
// `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
Value zeroCheck =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
- Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
- expResult);
- res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
- rewriter.replaceOp(op, res);
+ Value finalResult =
+ b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, signedExpResult);
+ rewriter.replaceOp(op, finalResult);
return success();
}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 6055ed0504c84c..a6c6e51ab88e25 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -202,24 +202,23 @@ func.func @roundf_func(%a: f32) -> f32 {
// CHECK-LABEL: func @powf_func
// CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
-func.func @powf_func(%a: f64, %b: f64) ->f64 {
+func.func @powf_func(%a: f64, %b: f64) -> f64 {
// CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
// CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0
- // CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
- // CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00
- // CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]]
- // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
- // CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]]
- // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]]
- // CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
- // CHECK-DAG: [[NEGEXPR:%.+]] = arith.mulf [[EXPR]], [[NEGONE]]
- // CHECK-DAG: [[REMF:%.+]] = arith.remf [[ARG1]], [[TWO]]
- // CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]]
- // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]]
- // CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
- // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
- // CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
- // CHECK-DAG: [[SEL1:%.+]] = arith.select [[CMPZERO]], [[CST1]], [[SEL]]
+ // CHECK-DAG: [[CSTNEG1:%.+]] = arith.constant -1.000000e+00
+ // CHECK-DAG: [[CSTTWO:%.+]] = arith.constant 2.000000e+00
+ // CHECK: [[ABSA:%.+]] = math.absf [[ARG0]]
+ // CHECK: [[ISNEG:%.+]] = arith.cmpf olt, [[ARG0]], [[CST0]]
+ // CHECK: [[SIGNA:%.+]] = arith.select [[ISNEG]], [[CSTNEG1]], [[CST1]]
+ // CHECK: [[LOGA:%.+]] = math.log [[ABSA]]
+ // CHECK: [[MULB:%.+]] = arith.mulf [[ARG1]], [[LOGA]]
+ // CHECK: [[EXP:%.+]] = math.exp [[MULB]]
+ // CHECK: [[REM:%.+]] = arith.remf [[ARG1]], [[CSTTWO]]
+ // CHECK: [[CMPF:%.+]] = arith.cmpf one, [[REM]], [[CST0]]
+ // CHECK: [[ABMUL:%.+]] = arith.mulf [[EXP]], [[SIGNA]]
+ // CHECK: [[SEL0:%.+]] = arith.select [[CMPF]], [[ABMUL]], [[EXP]]
+ // CHECK: [[CMPF2:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
+// CHECK: [[SEL1:%.+]] = arith.select [[CMPF2]], [[CST1]], [[SEL0]]
// CHECK: return [[SEL1]]
%ret = math.powf %a, %b : f64
return %ret : f64
@@ -602,26 +601,24 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
return %2 : tensor<8xf32>
}
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
-// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
-// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
+// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
-// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
-// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
-// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32>
-// CHECK: %[[LG:.*]] = math.log %[[SQ]] : tensor<8xf32>
-// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : tensor<8xf32>
-// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
-// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : tensor<8xf32>
-// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : tensor<8xf32>
-// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
-// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
-// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1>
-// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
-// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
-// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
-// CHECK: return %[[SEL1]] : tensor<8xf32>
-
+// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
+// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
+// CHECK: %[[ABSA:.*]] = math.absf %[[ARG0]] : tensor<8xf32>
+// CHECK: %[[ISNEG:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
+// CHECK: %[[SIGNA:.*]] = arith.select %[[ISNEG]], %[[CSTNEG1]], %[[CST1]] : tensor<8xi1>, tensor<8xf32>
+// CHECK: %[[LOGA:.*]] = math.log %[[ABSA]] : tensor<8xf32>
+// CHECK: %[[MULA:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32>
+// CHECK: %[[EXPA:.*]] = math.exp %[[MULA]] : tensor<8xf32>
+// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : tensor<8xf32>
+// CHECK: %[[CMPF:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
+// CHECK: %[[ABMUL:.*]] = arith.mulf %[[EXPA]], %[[SIGNA]] : tensor<8xf32>
+// CHECK: %[[SEL0:.*]] = arith.select %[[CMPF]], %[[ABMUL]], %[[EXPA]] : tensor<8xi1>, tensor<8xf32>
+// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : tensor<8xf32>
+// CHECK: %[[SEL1:.*]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL0]] : tensor<8xi1>, tensor<8xf32>
+// CHECK: return %[[SEL1]]
// -----
// CHECK-LABEL: func.func @math_fpowi_to_powf_scalar
@@ -635,19 +632,18 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
// CHECK-DAG: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
-// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
-// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32
-// CHECK: %[[LG:.*]] = math.log %[[SQ]] : f32
-// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : f32
-// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : f32
-// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : f32
+// CHECK: %[[ABSA:.*]] = math.absf %[[ARG0]] : f32
+// CHECK: %[[ISNEG:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
+// CHECK: %[[SIGNA:.*]] = arith.select %[[ISNEG]], %[[CSTNEG1]], %[[CST1]] : f32
+// CHECK: %[[LOGA:.*]] = math.log %[[ABSA]] : f32
+// CHECK: %[[MULA:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : f32
+// CHECK: %[[EXPA:.*]] = math.exp %[[MULA]] : f32
// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : f32
-// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
-// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
-// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
-// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
-// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
-// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
+// CHECK: %[[CMPF:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
+// CHECK: %[[ABMUL:.*]] = arith.mulf %[[EXPA]], %[[SIGNA]] : f32
+// CHECK: %[[SEL0:.*]] = arith.select %[[CMPF]], %[[ABMUL]], %[[EXPA]] : f32
+// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : f32
+// CHECK: %[[SEL1:.*]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL0]] : f32
// CHECK: return %[[SEL1]] : f32
// -----
>From c477e70b97bf5c9a066af38e9991478ae6c79d64 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Wed, 29 Jan 2025 10:59:04 +0900
Subject: [PATCH 2/2] remove case a<0
---
.../Math/Transforms/ExpandPatterns.cpp | 25 +++-----
mlir/test/Dialect/Math/expand-math.mlir | 59 +++++--------------
.../mlir-runner/test-expand-math-approx.mlir | 5 --
3 files changed, 23 insertions(+), 66 deletions(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index df74bc3982298c..30bcdfc45837a6 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -311,8 +311,8 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
return success();
}
-// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(|a|))
-// * sign(a)^b
+// Converts Powf(float a, float b) (meaning a^b) to exp(b * ln(a)). Only valid
+// for a >= 0: for a < 0 (and b != 0), ln(a) is NaN, so the result is NaN.
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operandA = op.getOperand(0);
@@ -320,30 +320,19 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
Type opType = operandA.getType();
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
- Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
- Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
-
- Value absA = b.create<math::AbsFOp>(opType, operandA);
- Value isNegative =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
- Value signA =
- b.create<arith::SelectOp>(op->getLoc(), isNegative, negOne, one);
- Value logA = b.create<math::LogOp>(opType, absA);
+
+ Value logA = b.create<math::LogOp>(opType, operandA);
Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
Value expResult = b.create<math::ExpOp>(opType, mult);
- Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
- Value isOdd =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
- Value signedExpResult = b.create<arith::SelectOp>(
- op->getLoc(), isOdd, b.create<arith::MulFOp>(opType, expResult, signA),
- expResult);
+ // Produce one whenever `b` compares equal to zero, whatever the value of `a`
+ // (this also covers a < 0, where ln(a) would otherwise make the result NaN).
// This corresponds to `libm` behavior, even for `0^0`. Without this check,
// `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
Value zeroCheck =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
Value finalResult =
- b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, signedExpResult);
+ b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, expResult);
rewriter.replaceOp(op, finalResult);
return success();
}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index a6c6e51ab88e25..5b443e9e8d4e78 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -205,21 +205,12 @@ func.func @roundf_func(%a: f32) -> f32 {
func.func @powf_func(%a: f64, %b: f64) -> f64 {
// CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
// CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0
- // CHECK-DAG: [[CSTNEG1:%.+]] = arith.constant -1.000000e+00
- // CHECK-DAG: [[CSTTWO:%.+]] = arith.constant 2.000000e+00
- // CHECK: [[ABSA:%.+]] = math.absf [[ARG0]]
- // CHECK: [[ISNEG:%.+]] = arith.cmpf olt, [[ARG0]], [[CST0]]
- // CHECK: [[SIGNA:%.+]] = arith.select [[ISNEG]], [[CSTNEG1]], [[CST1]]
- // CHECK: [[LOGA:%.+]] = math.log [[ABSA]]
+ // CHECK: [[LOGA:%.+]] = math.log [[ARG0]]
// CHECK: [[MULB:%.+]] = arith.mulf [[ARG1]], [[LOGA]]
// CHECK: [[EXP:%.+]] = math.exp [[MULB]]
- // CHECK: [[REM:%.+]] = arith.remf [[ARG1]], [[CSTTWO]]
- // CHECK: [[CMPF:%.+]] = arith.cmpf one, [[REM]], [[CST0]]
- // CHECK: [[ABMUL:%.+]] = arith.mulf [[EXP]], [[SIGNA]]
- // CHECK: [[SEL0:%.+]] = arith.select [[CMPF]], [[ABMUL]], [[EXP]]
- // CHECK: [[CMPF2:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
-// CHECK: [[SEL1:%.+]] = arith.select [[CMPF2]], [[CST1]], [[SEL0]]
- // CHECK: return [[SEL1]]
+ // CHECK: [[CMPF:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
+ // CHECK: [[SEL:%.+]] = arith.select [[CMPF]], [[CST1]], [[EXP]]
+ // CHECK: return [[SEL]]
%ret = math.powf %a, %b : f64
return %ret : f64
}
@@ -601,24 +592,15 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
return %2 : tensor<8xf32>
}
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
-// CHECK-DAG: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
-// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
-// CHECK: %[[ABSA:.*]] = math.absf %[[ARG0]] : tensor<8xf32>
-// CHECK: %[[ISNEG:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
-// CHECK: %[[SIGNA:.*]] = arith.select %[[ISNEG]], %[[CSTNEG1]], %[[CST1]] : tensor<8xi1>, tensor<8xf32>
-// CHECK: %[[LOGA:.*]] = math.log %[[ABSA]] : tensor<8xf32>
-// CHECK: %[[MULA:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32>
-// CHECK: %[[EXPA:.*]] = math.exp %[[MULA]] : tensor<8xf32>
-// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : tensor<8xf32>
-// CHECK: %[[CMPF:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
-// CHECK: %[[ABMUL:.*]] = arith.mulf %[[EXPA]], %[[SIGNA]] : tensor<8xf32>
-// CHECK: %[[SEL0:.*]] = arith.select %[[CMPF]], %[[ABMUL]], %[[EXPA]] : tensor<8xi1>, tensor<8xf32>
-// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : tensor<8xf32>
-// CHECK: %[[SEL1:.*]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL0]] : tensor<8xi1>, tensor<8xf32>
-// CHECK: return %[[SEL1]]
+// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : tensor<8xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32>
+// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
+// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : tensor<8xf32>
+// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
+// CHECK: return %[[SEL]]
// -----
// CHECK-LABEL: func.func @math_fpowi_to_powf_scalar
@@ -627,24 +609,15 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
return %2 : f32
}
// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 {
-// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
-// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
// CHECK-DAG: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
-// CHECK: %[[ABSA:.*]] = math.absf %[[ARG0]] : f32
-// CHECK: %[[ISNEG:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
-// CHECK: %[[SIGNA:.*]] = arith.select %[[ISNEG]], %[[CSTNEG1]], %[[CST1]] : f32
-// CHECK: %[[LOGA:.*]] = math.log %[[ABSA]] : f32
-// CHECK: %[[MULA:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : f32
-// CHECK: %[[EXPA:.*]] = math.exp %[[MULA]] : f32
-// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : f32
-// CHECK: %[[CMPF:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
-// CHECK: %[[ABMUL:.*]] = arith.mulf %[[EXPA]], %[[SIGNA]] : f32
-// CHECK: %[[SEL0:.*]] = arith.select %[[CMPF]], %[[ABMUL]], %[[EXPA]] : f32
-// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : f32
-// CHECK: %[[SEL1:.*]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL0]] : f32
-// CHECK: return %[[SEL1]] : f32
+// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : f32
+// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : f32
+// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : f32
+// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : f32
+// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : f32
+// CHECK: return %[[SEL]] : f32
// -----
diff --git a/mlir/test/mlir-runner/test-expand-math-approx.mlir b/mlir/test/mlir-runner/test-expand-math-approx.mlir
index 106b48a2daea2e..d1916c28878b97 100644
--- a/mlir/test/mlir-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-runner/test-expand-math-approx.mlir
@@ -202,11 +202,6 @@ func.func @powf() {
%a_p = arith.constant 2.0 : f64
call @func_powff64(%a, %a_p) : (f64, f64) -> ()
- // CHECK-NEXT: -27
- %b = arith.constant -3.0 : f64
- %b_p = arith.constant 3.0 : f64
- call @func_powff64(%b, %b_p) : (f64, f64) -> ()
-
// CHECK-NEXT: 2.343
%c = arith.constant 2.343 : f64
%c_p = arith.constant 1.000 : f64
More information about the Mlir-commits
mailing list