[Mlir-commits] [mlir] [mlir][math] Update `convertPowfOp` in `ExpandPatterns.cpp` (PR #124402)

Hyunsung Lee llvmlistbot at llvm.org
Sat Jan 25 15:45:50 PST 2025


https://github.com/ita9naiwa updated https://github.com/llvm/llvm-project/pull/124402

>From 5e65430b11da9d739b8b0bd82a0d51bdaaefc6cd Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sat, 25 Jan 2025 20:01:11 +0900
Subject: [PATCH 1/7] Update ExpandPatterns.cpp

---
 .../Math/Transforms/ExpandPatterns.cpp        | 45 +++++++++++--------
 1 file changed, 27 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 3dadf9474cf4f6..314b5b30202064 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -317,34 +317,43 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   Value operandA = op.getOperand(0);
   Value operandB = op.getOperand(1);
   Type opType = operandA.getType();
+
+  // Constants
   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
   Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
-  Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
   Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
-  Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
-  Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);
+  Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
+
+  // Compute |a| (absolute value of operandA)
+  Value absA = b.create<math::AbsFOp>(opType, operandA);
+
+  // Compute sign(a) as -1.0 if a < 0, else 1.0
+  Value isNegative = b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
+  Value signA = b.create<arith::SelectOp>(op->getLoc(), isNegative, negOne, one);
 
-  Value logA = b.create<math::LogOp>(opType, opASquared);
-  Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
+  // Compute ln(|a|)
+  Value logA = b.create<math::LogOp>(opType, absA);
+
+  // Compute b * ln(|a|)
+  Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
+
+  // Compute exp(b * ln(|a|))
   Value expResult = b.create<math::ExpOp>(opType, mult);
-  Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
-  Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
-  Value negCheck =
-      b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
-  Value oddPower =
-      b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
-  Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
+  Value logSign = b.create<math::LogOp>(opType, signA);
+  Value signMult = b.create<arith::MulFOp>(opType, operandB, logSign);
+  Value signPow = b.create<math::ExpOp>(opType, signMult);
+
+  Value resultWithSign = b.create<arith::MulFOp>(opType, expResult, signPow);
 
   // First, we select between the exp value and the adjusted value for odd
   // powers of negatives. Then, we ensure that one is produced if `b` is zero.
   // This corresponds to `libm` behavior, even for `0^0`. Without this check,
   // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
-  Value zeroCheck =
-      b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
-  Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
-                                        expResult);
-  res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
-  rewriter.replaceOp(op, res);
+  Value zeroCheck = 
+    b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
+  Value finalResult = 
+    b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, resultWithSign);
+  rewriter.replaceOp(op, finalResult);
   return success();
 }
 

>From ce55a3fb40b2b010aba6486f698548c346924c26 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sat, 25 Jan 2025 20:16:39 +0900
Subject: [PATCH 2/7] formatting

---
 .../Dialect/Math/Transforms/ExpandPatterns.cpp    | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 314b5b30202064..cdfafc2db52535 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -318,7 +318,6 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   Value operandB = op.getOperand(1);
   Type opType = operandA.getType();
 
-  // Constants
   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
   Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
   Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
@@ -328,8 +327,10 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   Value absA = b.create<math::AbsFOp>(opType, operandA);
 
   // Compute sign(a) as -1.0 if a < 0, else 1.0
-  Value isNegative = b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
-  Value signA = b.create<arith::SelectOp>(op->getLoc(), isNegative, negOne, one);
+  Value isNegative =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
+  Value signA =
+      b.create<arith::SelectOp>(op->getLoc(), isNegative, negOne, one);
 
   // Compute ln(|a|)
   Value logA = b.create<math::LogOp>(opType, absA);
@@ -349,10 +350,10 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   // powers of negatives. Then, we ensure that one is produced if `b` is zero.
   // This corresponds to `libm` behavior, even for `0^0`. Without this check,
   // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
-  Value zeroCheck = 
-    b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
-  Value finalResult = 
-    b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, resultWithSign);
+  Value zeroCheck =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
+  Value finalResult =
+      b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, resultWithSign);
   rewriter.replaceOp(op, finalResult);
   return success();
 }

>From 20ce34ecd08dffc3f98d53a55c8d338fa96c025f Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sat, 25 Jan 2025 21:36:57 +0900
Subject: [PATCH 3/7] .

---
 .../Math/Transforms/ExpandPatterns.cpp        | 32 ++++++-------------
 1 file changed, 10 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index cdfafc2db52535..f2b32a25316c23 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -311,7 +311,7 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
   return success();
 }
 
-// Converts  Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
+// Converts Powf(float a, float b) (meaning a^b) to exp(b * ln(|a|)) * sign(a)^b
 static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operandA = op.getOperand(0);
@@ -323,41 +323,29 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
   Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
 
-  // Compute |a| (absolute value of operandA)
   Value absA = b.create<math::AbsFOp>(opType, operandA);
-
-  // Compute sign(a) as -1.0 if a < 0, else 1.0
-  Value isNegative =
-      b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
-  Value signA =
-      b.create<arith::SelectOp>(op->getLoc(), isNegative, negOne, one);
-
-  // Compute ln(|a|)
+  Value isNegative = b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
+  Value signA = b.create<arith::SelectOp>(op->getLoc(), isNegative, negOne, one);
   Value logA = b.create<math::LogOp>(opType, absA);
-
-  // Compute b * ln(|a|)
   Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
-
-  // Compute exp(b * ln(|a|))
   Value expResult = b.create<math::ExpOp>(opType, mult);
-  Value logSign = b.create<math::LogOp>(opType, signA);
-  Value signMult = b.create<arith::MulFOp>(opType, operandB, logSign);
-  Value signPow = b.create<math::ExpOp>(opType, signMult);
-
-  Value resultWithSign = b.create<arith::MulFOp>(opType, expResult, signPow);
+  Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
+  Value isOdd = b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
+  Value signedExpResult = b.create<arith::SelectOp>(op->getLoc(), isOdd,
+      b.create<arith::MulFOp>(opType, expResult, signA), expResult);
 
-  // First, we select between the exp value and the adjusted value for odd
-  // powers of negatives. Then, we ensure that one is produced if `b` is zero.
   // This corresponds to `libm` behavior, even for `0^0`. Without this check,
   // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
   Value zeroCheck =
       b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
   Value finalResult =
-      b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, resultWithSign);
+      b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, signedExpResult);
+  // Replace the original operation
   rewriter.replaceOp(op, finalResult);
   return success();
 }
 
+
 // exp2f(float x) -> exp(x * ln(2))
 //   Proof: Let's say 2^x = y
 //   ln(2^x) = ln(y)

>From 01fa2c1dcdc097a89388bbdeaef789404641d2c1 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sat, 25 Jan 2025 21:38:10 +0900
Subject: [PATCH 4/7] clear

---
 mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index f2b32a25316c23..b32790db4509ef 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -317,7 +317,6 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   Value operandA = op.getOperand(0);
   Value operandB = op.getOperand(1);
   Type opType = operandA.getType();
-
   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
   Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
   Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
@@ -340,7 +339,6 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
       b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
   Value finalResult =
       b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, signedExpResult);
-  // Replace the original operation
   rewriter.replaceOp(op, finalResult);
   return success();
 }

>From 9fa470e11f1156e96ffd57927d3401d7388bec80 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sat, 25 Jan 2025 21:38:41 +0900
Subject: [PATCH 5/7] clear

---
 mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index b32790db4509ef..902dc04bce9e65 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -323,7 +323,8 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
 
   Value absA = b.create<math::AbsFOp>(opType, operandA);
-  Value isNegative = b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
+  Value isNegative =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
   Value signA = b.create<arith::SelectOp>(op->getLoc(), isNegative, negOne, one);
   Value logA = b.create<math::LogOp>(opType, absA);
   Value mult = b.create<arith::MulFOp>(opType, operandB, logA);

>From 05552f3c58cc196e03b3ee7ff5b6495bb6c28968 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sat, 25 Jan 2025 21:45:58 +0900
Subject: [PATCH 6/7] fix

---
 mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 902dc04bce9e65..94296dc32dfe56 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -322,7 +322,8 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
   Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
 
-  Value absA = b.create<math::AbsFOp>(opType, operandA);
+  Value absA =
+      b.create<math::AbsFOp>(opType, operandA);
   Value isNegative =
       b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
   Value signA = b.create<arith::SelectOp>(op->getLoc(), isNegative, negOne, one);
@@ -330,9 +331,11 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
   Value expResult = b.create<math::ExpOp>(opType, mult);
   Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
-  Value isOdd = b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
-  Value signedExpResult = b.create<arith::SelectOp>(op->getLoc(), isOdd,
-      b.create<arith::MulFOp>(opType, expResult, signA), expResult);
+  Value isOdd =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
+  Value signedExpResult = b.create<arith::SelectOp>(
+      op->getLoc(), isOdd,b.create<arith::MulFOp>(opType, expResult, signA),
+      expResult);
 
   // This corresponds to `libm` behavior, even for `0^0`. Without this check,
   // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
@@ -344,7 +347,6 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   return success();
 }
 
-
 // exp2f(float x) -> exp(x * ln(2))
 //   Proof: Let's say 2^x = y
 //   ln(2^x) = ln(y)

>From 0ff54876891fc221944b72b073e590c867ed2557 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sun, 26 Jan 2025 08:43:43 +0900
Subject: [PATCH 7/7] fix tests

---
 mlir/test/Dialect/Math/expand-math.mlir | 88 ++++++++++++-------------
 1 file changed, 42 insertions(+), 46 deletions(-)

diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 6055ed0504c84c..a6c6e51ab88e25 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -202,24 +202,23 @@ func.func @roundf_func(%a: f32) -> f32 {
 
 // CHECK-LABEL:   func @powf_func
 // CHECK-SAME:    ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
-func.func @powf_func(%a: f64, %b: f64) ->f64 {
+func.func @powf_func(%a: f64, %b: f64) -> f64 {
   // CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
   // CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0
-  // CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
-  // CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00
-  // CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]]
-  // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
-  // CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]]
-  // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]]
-  // CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
-  // CHECK-DAG: [[NEGEXPR:%.+]] = arith.mulf [[EXPR]], [[NEGONE]]
-  // CHECK-DAG: [[REMF:%.+]] = arith.remf [[ARG1]], [[TWO]]
-  // CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]]
-  // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]]
-  // CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
-  // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
-  // CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
-  // CHECK-DAG: [[SEL1:%.+]] = arith.select [[CMPZERO]], [[CST1]], [[SEL]]
+  // CHECK-DAG: [[CSTNEG1:%.+]] = arith.constant -1.000000e+00
+  // CHECK-DAG: [[CSTTWO:%.+]] = arith.constant 2.000000e+00
+  // CHECK: [[ABSA:%.+]] = math.absf [[ARG0]]
+  // CHECK: [[ISNEG:%.+]] = arith.cmpf olt, [[ARG0]], [[CST0]]
+  // CHECK: [[SIGNA:%.+]] = arith.select [[ISNEG]], [[CSTNEG1]], [[CST1]]
+  // CHECK: [[LOGA:%.+]] = math.log [[ABSA]]
+  // CHECK: [[MULB:%.+]] = arith.mulf [[ARG1]], [[LOGA]]
+  // CHECK: [[EXP:%.+]] = math.exp [[MULB]]
+  // CHECK: [[REM:%.+]] = arith.remf [[ARG1]], [[CSTTWO]]
+  // CHECK: [[CMPF:%.+]] = arith.cmpf one, [[REM]], [[CST0]]
+  // CHECK: [[ABMUL:%.+]] = arith.mulf [[EXP]], [[SIGNA]]
+  // CHECK: [[SEL0:%.+]] = arith.select [[CMPF]], [[ABMUL]], [[EXP]]
+  // CHECK: [[CMPF2:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
+// CHECK: [[SEL1:%.+]] = arith.select [[CMPF2]], [[CST1]], [[SEL0]]
   // CHECK: return [[SEL1]]
   %ret = math.powf %a, %b : f64
   return %ret : f64
@@ -602,26 +601,24 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
   return %2 : tensor<8xf32>
 }
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
-// CHECK-DAG:    %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
 // CHECK-DAG:    %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
-// CHECK-DAG:    %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
+// CHECK-DAG:    %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
 // CHECK-DAG:    %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
-// CHECK:        %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
-// CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
-// CHECK:        %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32>
-// CHECK:        %[[LG:.*]] = math.log %[[SQ]] : tensor<8xf32>
-// CHECK:        %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : tensor<8xf32>
-// CHECK:        %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
-// CHECK:        %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : tensor<8xf32>
-// CHECK:        %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : tensor<8xf32>
-// CHECK:        %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
-// CHECK:        %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
-// CHECK:        %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1>
-// CHECK:        %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
-// CHECK:        %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
-// CHECK:        %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
-// CHECK:      return %[[SEL1]] : tensor<8xf32>
-
+// CHECK-DAG:    %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
+// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
+// CHECK: %[[ABSA:.*]] = math.absf %[[ARG0]] : tensor<8xf32>
+// CHECK: %[[ISNEG:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
+// CHECK: %[[SIGNA:.*]] = arith.select %[[ISNEG]], %[[CSTNEG1]], %[[CST1]] : tensor<8xi1>, tensor<8xf32>
+// CHECK: %[[LOGA:.*]] = math.log %[[ABSA]] : tensor<8xf32>
+// CHECK: %[[MULA:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32>
+// CHECK: %[[EXPA:.*]] = math.exp %[[MULA]] : tensor<8xf32>
+// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : tensor<8xf32>
+// CHECK: %[[CMPF:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
+// CHECK: %[[ABMUL:.*]] = arith.mulf %[[EXPA]], %[[SIGNA]] : tensor<8xf32>
+// CHECK: %[[SEL0:.*]] = arith.select %[[CMPF]], %[[ABMUL]], %[[EXPA]] : tensor<8xi1>, tensor<8xf32>
+// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : tensor<8xf32>
+// CHECK: %[[SEL1:.*]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL0]] : tensor<8xi1>, tensor<8xf32>
+// CHECK: return %[[SEL1]]
 // -----
 
 // CHECK-LABEL:   func.func @math_fpowi_to_powf_scalar
@@ -635,19 +632,18 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
 // CHECK-DAG:    %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:    %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
 // CHECK:        %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
-// CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
-// CHECK:        %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32
-// CHECK:        %[[LG:.*]] = math.log %[[SQ]] : f32
-// CHECK:        %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : f32
-// CHECK:        %[[EXP:.*]] = math.exp %[[MUL]] : f32
-// CHECK:        %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : f32
+// CHECK:        %[[ABSA:.*]] = math.absf %[[ARG0]] : f32
+// CHECK:        %[[ISNEG:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
+// CHECK:        %[[SIGNA:.*]] = arith.select %[[ISNEG]], %[[CSTNEG1]], %[[CST1]] : f32
+// CHECK:        %[[LOGA:.*]] = math.log %[[ABSA]] : f32
+// CHECK:        %[[MULA:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : f32
+// CHECK:        %[[EXPA:.*]] = math.exp %[[MULA]] : f32
 // CHECK:        %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : f32
-// CHECK:        %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
-// CHECK:        %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
-// CHECK:        %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
-// CHECK:        %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
-// CHECK:        %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
-// CHECK:        %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
+// CHECK:        %[[CMPF:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
+// CHECK:        %[[ABMUL:.*]] = arith.mulf %[[EXPA]], %[[SIGNA]] : f32
+// CHECK:        %[[SEL0:.*]] = arith.select %[[CMPF]], %[[ABMUL]], %[[EXPA]] : f32
+// CHECK:        %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : f32
+// CHECK:        %[[SEL1:.*]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL0]] : f32
 // CHECK:       return %[[SEL1]] : f32
 
 // -----



More information about the Mlir-commits mailing list