[Mlir-commits] [mlir] [mlir][math] Update `convertPowfOp` in `ExpandPatterns.cpp` (PR #124402)
Hyunsung Lee
llvmlistbot at llvm.org
Sat Jan 25 04:46:42 PST 2025
https://github.com/ita9naiwa updated https://github.com/llvm/llvm-project/pull/124402
>From 5e65430b11da9d739b8b0bd82a0d51bdaaefc6cd Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sat, 25 Jan 2025 20:01:11 +0900
Subject: [PATCH 1/6] Update ExpandPatterns.cpp
---
.../Math/Transforms/ExpandPatterns.cpp | 45 +++++++++++--------
1 file changed, 27 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 3dadf9474cf4f6..314b5b30202064 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -317,34 +317,43 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
Value operandA = op.getOperand(0);
Value operandB = op.getOperand(1);
Type opType = operandA.getType();
+
+ // Constants
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
- Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
- Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
- Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);
+ Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
+
+ // Compute |a| (absolute value of operandA)
+ Value absA = b.create<math::AbsFOp>(opType, operandA);
+
+ // Compute sign(a) as -1.0 if a < 0, else 1.0
+ Value isNegative = b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
+ Value signA = b.create<arith::SelectOp>(op->getLoc(), isNegative, negOne, one);
- Value logA = b.create<math::LogOp>(opType, opASquared);
- Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
+ // Compute ln(|a|)
+ Value logA = b.create<math::LogOp>(opType, absA);
+
+ // Compute b * ln(|a|)
+ Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
+
+ // Compute exp(b * ln(|a|))
Value expResult = b.create<math::ExpOp>(opType, mult);
- Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
- Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
- Value negCheck =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
- Value oddPower =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
- Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
+ Value logSign = b.create<math::LogOp>(opType, signA);
+ Value signMult = b.create<arith::MulFOp>(opType, operandB, logSign);
+ Value signPow = b.create<math::ExpOp>(opType, signMult);
+
+ Value resultWithSign = b.create<arith::MulFOp>(opType, expResult, signPow);
// First, we select between the exp value and the adjusted value for odd
// powers of negatives. Then, we ensure that one is produced if `b` is zero.
// This corresponds to `libm` behavior, even for `0^0`. Without this check,
// `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
- Value zeroCheck =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
- Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
- expResult);
- res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
- rewriter.replaceOp(op, res);
+ Value zeroCheck =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
+ Value finalResult =
+ b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, resultWithSign);
+ rewriter.replaceOp(op, finalResult);
return success();
}
>From ce55a3fb40b2b010aba6486f698548c346924c26 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sat, 25 Jan 2025 20:16:39 +0900
Subject: [PATCH 2/6] formatting
---
.../Dialect/Math/Transforms/ExpandPatterns.cpp | 15 ++++++++-------
1 file changed, 8 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 314b5b30202064..cdfafc2db52535 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -318,7 +318,6 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
Value operandB = op.getOperand(1);
Type opType = operandA.getType();
- // Constants
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
@@ -328,8 +327,10 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
Value absA = b.create<math::AbsFOp>(opType, operandA);
// Compute sign(a) as -1.0 if a < 0, else 1.0
- Value isNegative = b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
- Value signA = b.create<arith::SelectOp>(op->getLoc(), isNegative, negOne, one);
+ Value isNegative =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
+ Value signA =
+ b.create<arith::SelectOp>(op->getLoc(), isNegative, negOne, one);
// Compute ln(|a|)
Value logA = b.create<math::LogOp>(opType, absA);
@@ -349,10 +350,10 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
// powers of negatives. Then, we ensure that one is produced if `b` is zero.
// This corresponds to `libm` behavior, even for `0^0`. Without this check,
// `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
- Value zeroCheck =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
- Value finalResult =
- b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, resultWithSign);
+ Value zeroCheck =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
+ Value finalResult =
+ b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, resultWithSign);
rewriter.replaceOp(op, finalResult);
return success();
}
>From 20ce34ecd08dffc3f98d53a55c8d338fa96c025f Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sat, 25 Jan 2025 21:36:57 +0900
Subject: [PATCH 3/6] Apply sign(a) only for odd exponents via remf parity check instead of log(sign(a))
---
.../Math/Transforms/ExpandPatterns.cpp | 32 ++++++-------------
1 file changed, 10 insertions(+), 22 deletions(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index cdfafc2db52535..f2b32a25316c23 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -311,7 +311,7 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
return success();
}
-// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
+// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(|a|)) * sign(a)^b
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operandA = op.getOperand(0);
@@ -323,41 +323,29 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
- // Compute |a| (absolute value of operandA)
Value absA = b.create<math::AbsFOp>(opType, operandA);
-
- // Compute sign(a) as -1.0 if a < 0, else 1.0
- Value isNegative =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
- Value signA =
- b.create<arith::SelectOp>(op->getLoc(), isNegative, negOne, one);
-
- // Compute ln(|a|)
+ Value isNegative = b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
+ Value signA = b.create<arith::SelectOp>(op->getLoc(), isNegative, negOne, one);
Value logA = b.create<math::LogOp>(opType, absA);
-
- // Compute b * ln(|a|)
Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
-
- // Compute exp(b * ln(|a|))
Value expResult = b.create<math::ExpOp>(opType, mult);
- Value logSign = b.create<math::LogOp>(opType, signA);
- Value signMult = b.create<arith::MulFOp>(opType, operandB, logSign);
- Value signPow = b.create<math::ExpOp>(opType, signMult);
-
- Value resultWithSign = b.create<arith::MulFOp>(opType, expResult, signPow);
+ Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
+ Value isOdd = b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
+ Value signedExpResult = b.create<arith::SelectOp>(op->getLoc(), isOdd,
+ b.create<arith::MulFOp>(opType, expResult, signA), expResult);
- // First, we select between the exp value and the adjusted value for odd
- // powers of negatives. Then, we ensure that one is produced if `b` is zero.
// This corresponds to `libm` behavior, even for `0^0`. Without this check,
// `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
Value zeroCheck =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
Value finalResult =
- b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, resultWithSign);
+ b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, signedExpResult);
+ // Replace the original operation
rewriter.replaceOp(op, finalResult);
return success();
}
+
// exp2f(float x) -> exp(x * ln(2))
// Proof: Let's say 2^x = y
// ln(2^x) = ln(y)
>From 01fa2c1dcdc097a89388bbdeaef789404641d2c1 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sat, 25 Jan 2025 21:38:10 +0900
Subject: [PATCH 4/6] Remove stray blank line and redundant comment
---
mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp | 2 --
1 file changed, 2 deletions(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index f2b32a25316c23..b32790db4509ef 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -317,7 +317,6 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
Value operandA = op.getOperand(0);
Value operandB = op.getOperand(1);
Type opType = operandA.getType();
-
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
@@ -340,7 +339,6 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
Value finalResult =
b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, signedExpResult);
- // Replace the original operation
rewriter.replaceOp(op, finalResult);
return success();
}
>From 9fa470e11f1156e96ffd57927d3401d7388bec80 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sat, 25 Jan 2025 21:38:41 +0900
Subject: [PATCH 5/6] Wrap long isNegative line to fit 80 columns
---
mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index b32790db4509ef..902dc04bce9e65 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -323,7 +323,8 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
Value absA = b.create<math::AbsFOp>(opType, operandA);
- Value isNegative = b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
+ Value isNegative =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
Value signA = b.create<arith::SelectOp>(op->getLoc(), isNegative, negOne, one);
Value logA = b.create<math::LogOp>(opType, absA);
Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
>From 05552f3c58cc196e03b3ee7ff5b6495bb6c28968 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sat, 25 Jan 2025 21:45:58 +0900
Subject: [PATCH 6/6] Rewrap absA/isOdd/signedExpResult statements and drop extra blank line
---
mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp | 12 +++++++-----
1 file changed, 7 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 902dc04bce9e65..94296dc32dfe56 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -322,7 +322,8 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
- Value absA = b.create<math::AbsFOp>(opType, operandA);
+ Value absA =
+ b.create<math::AbsFOp>(opType, operandA);
Value isNegative =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
Value signA = b.create<arith::SelectOp>(op->getLoc(), isNegative, negOne, one);
@@ -330,9 +331,11 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
Value expResult = b.create<math::ExpOp>(opType, mult);
Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
- Value isOdd = b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
- Value signedExpResult = b.create<arith::SelectOp>(op->getLoc(), isOdd,
- b.create<arith::MulFOp>(opType, expResult, signA), expResult);
+ Value isOdd =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
+ Value signedExpResult = b.create<arith::SelectOp>(
+ op->getLoc(), isOdd,b.create<arith::MulFOp>(opType, expResult, signA),
+ expResult);
// This corresponds to `libm` behavior, even for `0^0`. Without this check,
// `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
@@ -344,7 +347,6 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
return success();
}
-
// exp2f(float x) -> exp(x * ln(2))
// Proof: Let's say 2^x = y
// ln(2^x) = ln(y)
More information about the Mlir-commits
mailing list