[Mlir-commits] [mlir] [mlir][math] Update `convertPowfOp` in `ExpandPatterns.cpp` (PR #124402)

Hyunsung Lee llvmlistbot at llvm.org
Sat Jan 25 03:17:45 PST 2025


https://github.com/ita9naiwa updated https://github.com/llvm/llvm-project/pull/124402

From 5e65430b11da9d739b8b0bd82a0d51bdaaefc6cd Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sat, 25 Jan 2025 20:01:11 +0900
Subject: [PATCH 1/2] Update ExpandPatterns.cpp

---
 .../Math/Transforms/ExpandPatterns.cpp        | 45 +++++++++++--------
 1 file changed, 27 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 3dadf9474cf4f6..314b5b30202064 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -317,34 +317,43 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   Value operandA = op.getOperand(0);
   Value operandB = op.getOperand(1);
   Type opType = operandA.getType();
+
+  // Constants
   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
   Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
-  Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
   Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
-  Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
-  Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);
+  Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
+
+  // Compute |a| (absolute value of operandA)
+  Value absA = b.create<math::AbsFOp>(opType, operandA);
+
+  // Compute sign(a) as -1.0 if a < 0, else 1.0
+  Value isNegative = b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
+  Value signA = b.create<arith::SelectOp>(op->getLoc(), isNegative, negOne, one);
 
-  Value logA = b.create<math::LogOp>(opType, opASquared);
-  Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
+  // Compute ln(|a|)
+  Value logA = b.create<math::LogOp>(opType, absA);
+
+  // Compute b * ln(|a|)
+  Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
+
+  // Compute exp(b * ln(|a|))
   Value expResult = b.create<math::ExpOp>(opType, mult);
-  Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
-  Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
-  Value negCheck =
-      b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
-  Value oddPower =
-      b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
-  Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
+  Value logSign = b.create<math::LogOp>(opType, signA);
+  Value signMult = b.create<arith::MulFOp>(opType, operandB, logSign);
+  Value signPow = b.create<math::ExpOp>(opType, signMult);
+
+  Value resultWithSign = b.create<arith::MulFOp>(opType, expResult, signPow);
 
   // First, we select between the exp value and the adjusted value for odd
   // powers of negatives. Then, we ensure that one is produced if `b` is zero.
   // This corresponds to `libm` behavior, even for `0^0`. Without this check,
   // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
-  Value zeroCheck =
-      b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
-  Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
-                                        expResult);
-  res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
-  rewriter.replaceOp(op, res);
+  Value zeroCheck = 
+    b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
+  Value finalResult = 
+    b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, resultWithSign);
+  rewriter.replaceOp(op, finalResult);
   return success();
 }
 

From ce55a3fb40b2b010aba6486f698548c346924c26 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sat, 25 Jan 2025 20:16:39 +0900
Subject: [PATCH 2/2] formatting

---
 .../Dialect/Math/Transforms/ExpandPatterns.cpp    | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 314b5b30202064..cdfafc2db52535 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -318,7 +318,6 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   Value operandB = op.getOperand(1);
   Type opType = operandA.getType();
 
-  // Constants
   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
   Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
   Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
@@ -328,8 +327,10 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   Value absA = b.create<math::AbsFOp>(opType, operandA);
 
   // Compute sign(a) as -1.0 if a < 0, else 1.0
-  Value isNegative = b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
-  Value signA = b.create<arith::SelectOp>(op->getLoc(), isNegative, negOne, one);
+  Value isNegative =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
+  Value signA =
+      b.create<arith::SelectOp>(op->getLoc(), isNegative, negOne, one);
 
   // Compute ln(|a|)
   Value logA = b.create<math::LogOp>(opType, absA);
@@ -349,10 +350,10 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   // powers of negatives. Then, we ensure that one is produced if `b` is zero.
   // This corresponds to `libm` behavior, even for `0^0`. Without this check,
   // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
-  Value zeroCheck = 
-    b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
-  Value finalResult = 
-    b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, resultWithSign);
+  Value zeroCheck =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
+  Value finalResult =
+      b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, resultWithSign);
   rewriter.replaceOp(op, finalResult);
   return success();
 }



More information about the Mlir-commits mailing list