[Mlir-commits] [mlir] [mlir][math] `powf(a, b)` drop support when a < 0 (PR #126338)

Hyunsung Lee llvmlistbot at llvm.org
Wed Feb 12 16:18:35 PST 2025


https://github.com/ita9naiwa updated https://github.com/llvm/llvm-project/pull/126338

>From 0e790b6f4a51ba7ab3e7a805e6141108036bab0a Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Wed, 29 Jan 2025 12:56:43 +0900
Subject: [PATCH 01/12] [mlir][math] Update `convertPowfOp` in `ExpandPatterns.cpp`
 (#124402)

The current implementation of `convertPowfOp` computes `a * a`, which
overflows easily in low-precision types: max<fp16> ~= 65,504, so `a * a`
becomes INF in fp16 once |a| exceeds about 256, and in fp8 (e.g. E4M3,
max 448) once |a| exceeds only about 21.


Remove support for `a < 0`. Handling negative values of `a` needs extra
ops (`a * a`, `remf`, compares, and a select), and the squaring overflows
easily.

- related issue in iree:
https://github.com/iree-org/iree/issues/15936
---
 .../Math/Transforms/ExpandPatterns.cpp        | 25 ++-----
 mlir/test/Dialect/Math/expand-math.mlir       | 71 ++++++-------------
 .../mlir-runner/test-expand-math-approx.mlir  |  5 --
 3 files changed, 27 insertions(+), 74 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 3dadf9474cf4f..30bcdfc45837a 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -311,7 +311,8 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
   return success();
 }
 
-// Converts  Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
+// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
+// Restricting a >= 0
 static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operandA = op.getOperand(0);
@@ -319,21 +320,10 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   Type opType = operandA.getType();
   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
   Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
-  Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
-  Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
-  Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
-  Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);
 
-  Value logA = b.create<math::LogOp>(opType, opASquared);
-  Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
+  Value logA = b.create<math::LogOp>(opType, operandA);
+  Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
   Value expResult = b.create<math::ExpOp>(opType, mult);
-  Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
-  Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
-  Value negCheck =
-      b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
-  Value oddPower =
-      b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
-  Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
 
   // First, we select between the exp value and the adjusted value for odd
   // powers of negatives. Then, we ensure that one is produced if `b` is zero.
@@ -341,10 +331,9 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
   Value zeroCheck =
       b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
-  Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
-                                        expResult);
-  res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
-  rewriter.replaceOp(op, res);
+  Value finalResult =
+      b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, expResult);
+  rewriter.replaceOp(op, finalResult);
   return success();
 }
 
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 6055ed0504c84..5b443e9e8d4e7 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -202,25 +202,15 @@ func.func @roundf_func(%a: f32) -> f32 {
 
 // CHECK-LABEL:   func @powf_func
 // CHECK-SAME:    ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
-func.func @powf_func(%a: f64, %b: f64) ->f64 {
+func.func @powf_func(%a: f64, %b: f64) -> f64 {
   // CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
   // CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0
-  // CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
-  // CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00
-  // CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]]
-  // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
-  // CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]]
-  // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]]
-  // CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
-  // CHECK-DAG: [[NEGEXPR:%.+]] = arith.mulf [[EXPR]], [[NEGONE]]
-  // CHECK-DAG: [[REMF:%.+]] = arith.remf [[ARG1]], [[TWO]]
-  // CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]]
-  // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]]
-  // CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
-  // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
-  // CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
-  // CHECK-DAG: [[SEL1:%.+]] = arith.select [[CMPZERO]], [[CST1]], [[SEL]]
-  // CHECK: return [[SEL1]]
+  // CHECK: [[LOGA:%.+]] = math.log [[ARG0]]
+  // CHECK: [[MULB:%.+]] = arith.mulf [[ARG1]], [[LOGA]]
+  // CHECK: [[EXP:%.+]] = math.exp [[MULB]]
+  // CHECK: [[CMPF:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
+  // CHECK: [[SEL:%.+]] = arith.select [[CMPF]], [[CST1]], [[EXP]]
+  // CHECK: return [[SEL]]
   %ret = math.powf %a, %b : f64
   return %ret : f64
 }
@@ -602,26 +592,15 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
   return %2 : tensor<8xf32>
 }
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
-// CHECK-DAG:    %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
-// CHECK-DAG:    %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
-// CHECK-DAG:    %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
 // CHECK-DAG:    %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
-// CHECK:        %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
-// CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
-// CHECK:        %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32>
-// CHECK:        %[[LG:.*]] = math.log %[[SQ]] : tensor<8xf32>
-// CHECK:        %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : tensor<8xf32>
-// CHECK:        %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
-// CHECK:        %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : tensor<8xf32>
-// CHECK:        %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : tensor<8xf32>
-// CHECK:        %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
-// CHECK:        %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
-// CHECK:        %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1>
-// CHECK:        %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
-// CHECK:        %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
-// CHECK:        %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
-// CHECK:      return %[[SEL1]] : tensor<8xf32>
-
+// CHECK-DAG:    %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
+// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
+// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : tensor<8xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32>
+// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
+// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : tensor<8xf32>
+// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
+// CHECK: return %[[SEL]]
 // -----
 
 // CHECK-LABEL:   func.func @math_fpowi_to_powf_scalar
@@ -630,25 +609,15 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
   return %2 : f32
 }
 // CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 {
-// CHECK-DAG:    %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
-// CHECK-DAG:    %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
 // CHECK-DAG:    %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:    %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
 // CHECK:        %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
-// CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
-// CHECK:        %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32
-// CHECK:        %[[LG:.*]] = math.log %[[SQ]] : f32
-// CHECK:        %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : f32
+// CHECK:        %[[LOGA:.*]] = math.log %[[ARG0]] : f32
+// CHECK:        %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : f32
 // CHECK:        %[[EXP:.*]] = math.exp %[[MUL]] : f32
-// CHECK:        %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : f32
-// CHECK:        %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : f32
-// CHECK:        %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
-// CHECK:        %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
-// CHECK:        %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
-// CHECK:        %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
-// CHECK:        %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
-// CHECK:        %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
-// CHECK:       return %[[SEL1]] : f32
+// CHECK:        %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : f32
+// CHECK:        %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : f32
+// CHECK:       return %[[SEL]] : f32
 
 // -----
 
diff --git a/mlir/test/mlir-runner/test-expand-math-approx.mlir b/mlir/test/mlir-runner/test-expand-math-approx.mlir
index 106b48a2daea2..d1916c28878b9 100644
--- a/mlir/test/mlir-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-runner/test-expand-math-approx.mlir
@@ -202,11 +202,6 @@ func.func @powf() {
   %a_p = arith.constant 2.0 : f64
   call @func_powff64(%a, %a_p) : (f64, f64) -> ()
 
-  // CHECK-NEXT: -27
-  %b   = arith.constant -3.0 : f64
-  %b_p = arith.constant 3.0 : f64
-  call @func_powff64(%b, %b_p) : (f64, f64) -> ()
-
   // CHECK-NEXT: 2.343
   %c   = arith.constant 2.343 : f64
   %c_p = arith.constant 1.000 : f64

>From b248c2275c9d499695b3d63a96e65fcce88e9689 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sat, 8 Feb 2025 12:51:18 +0900
Subject: [PATCH 02/12] add special cases for handling powf

---
 .../Math/Transforms/ExpandPatterns.cpp        | 90 +++++++++++++++++++
 1 file changed, 90 insertions(+)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 30bcdfc45837a..235ea38dd87d1 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -17,8 +17,13 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/Support/LogicalResult.h"
+#include <cmath>
 
 using namespace mlir;
 
@@ -311,6 +316,90 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
   return success();
 }
 
+// Convert Powf(float a, float b) for some special cases
+// where b == 1.0, b == 0.0, b == 0.5, b == -0.5, b == -1.0, and b % 2 == 0
+static LogicalResult convertSpecialPowfOp(math::PowFOp op,
+                                          PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operandA = op.getOperand(0);
+  Value operandB = op.getOperand(1);
+  auto baseType = operandB.getType();
+
+  auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
+                  .getFloatSemantics();
+
+  auto valueB = APFloat(sem);
+  if (!matchPattern(operandB, m_ConstantFloat(&valueB))) {
+    // Not a constant, return failure
+    return failure();
+  }
+  float floatValueB = valueB.convertToFloat();
+
+  if (floatValueB == 1.0f) {
+    // a^1 -> a
+    rewriter.replaceOp(op, operandA);
+    return success();
+  }
+
+  if (floatValueB == 0.0) {
+    // a^0 -> 1
+    Value one =
+        createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
+    rewriter.replaceOp(op, one);
+    return success();
+  }
+
+  if (floatValueB == 0.5f) {
+    // a^(1/2) -> sqrt(a)
+    Value sqrt = b.create<math::SqrtOp>(operandA);
+    rewriter.replaceOp(op, sqrt);
+    return success();
+  }
+
+  if (floatValueB == -0.5f) {
+    // a^(-1/2) -> 1 / sqrt(a)
+    Value rsqrt = b.create<math::RsqrtOp>(operandA);
+    rewriter.replaceOp(op, rsqrt);
+    return success();
+  }
+
+  if (floatValueB == -1.0f) {
+    // a^(-1) -> 1 / a
+    Value one =
+        createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
+    Value div = b.create<arith::DivFOp>(one, operandA);
+    rewriter.replaceOp(op, div);
+    return success();
+  }
+
+  // Check if the power is an integer
+  if (floatValueB != std::floor(floatValueB)) {
+    // We don't handle non-integer powers here, return failure
+    return failure();
+  }
+
+  auto sign = std::signbit(floatValueB) ? -1 : 1;
+  auto absIntValueB = std::abs(static_cast<int>(floatValueB));
+
+  auto cstOne =
+      createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
+  auto base = operandA;
+  if (sign == -1) {
+    base = b.create<arith::DivFOp>(cstOne, base);
+  }
+  auto current = base;
+  auto result = cstOne;
+  while (absIntValueB > 0) {
+    if (absIntValueB & 1) {
+      result = b.create<arith::MulFOp>(result, current);
+    }
+    current = b.create<arith::MulFOp>(current, current);
+    absIntValueB >>= 1;
+  }
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
 // Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
 // Restricting a >= 0
 static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
@@ -649,6 +738,7 @@ void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
 }
 
 void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
+  patterns.add(convertSpecialPowfOp);
   patterns.add(convertPowfOp);
 }
 

>From 0e7dc199d7ee765ded899c753c12724ae21db96e Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Tue, 11 Feb 2025 11:57:58 +0900
Subject: [PATCH 03/12] add test

---
 .../Math/Transforms/ExpandPatterns.cpp        | 89 ++++++-------------
 mlir/test/Dialect/Math/expand-math.mlir       | 24 ++---
 .../mlir-runner/test-expand-math-approx.mlir  | 64 +++++++------
 3 files changed, 73 insertions(+), 104 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 235ea38dd87d1..9ad1ac2308838 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -9,7 +9,6 @@
 // This file implements expansion of various math operations.
 //
 //===----------------------------------------------------------------------===//
-
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
@@ -316,13 +315,14 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
   return success();
 }
 
-// Convert Powf(float a, float b) for some special cases
-// where b == 1.0, b == 0.0, b == 0.5, b == -0.5, b == -1.0, and b % 2 == 0
+// Convert Powf(float a, float b) for special cases when b is constant:
+// when b == 0, or |b| == 0.5, 1.0, or 2.0.
 static LogicalResult convertSpecialPowfOp(math::PowFOp op,
                                           PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operandA = op.getOperand(0);
   Value operandB = op.getOperand(1);
+  auto opType = operandA.getType();
   auto baseType = operandB.getType();
 
   auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
@@ -334,95 +334,64 @@ static LogicalResult convertSpecialPowfOp(math::PowFOp op,
     return failure();
   }
   float floatValueB = valueB.convertToFloat();
-
+  if (floatValueB == 0.0f) {
+    // a^0 -> 1
+    Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
+    rewriter.replaceOp(op, one);
+    return success();
+  }
   if (floatValueB == 1.0f) {
     // a^1 -> a
     rewriter.replaceOp(op, operandA);
     return success();
   }
-
-  if (floatValueB == 0.0) {
-    // a^0 -> 1
-    Value one =
-        createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
-    rewriter.replaceOp(op, one);
+  if (floatValueB == -1.0f) {
+    // a^(-1) -> 1 / a
+    Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
+    Value div = b.create<arith::DivFOp>(one, operandA);
+    rewriter.replaceOp(op, div);
     return success();
   }
-
   if (floatValueB == 0.5f) {
     // a^(1/2) -> sqrt(a)
     Value sqrt = b.create<math::SqrtOp>(operandA);
     rewriter.replaceOp(op, sqrt);
     return success();
   }
-
   if (floatValueB == -0.5f) {
     // a^(-1/2) -> 1 / sqrt(a)
     Value rsqrt = b.create<math::RsqrtOp>(operandA);
     rewriter.replaceOp(op, rsqrt);
     return success();
   }
-
-  if (floatValueB == -1.0f) {
-    // a^(-1) -> 1 / a
+  if (floatValueB == 2.0f) {
+    // a^2 -> a * a
+    Value mul = b.create<arith::MulFOp>(operandA, operandA);
+    rewriter.replaceOp(op, mul);
+    return success();
+  }
+  if (floatValueB == -2.0f) {
+    // a^(-2) -> 1 / (a * a)
+    Value mul = b.create<arith::MulFOp>(operandA, operandA);
     Value one =
         createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
-    Value div = b.create<arith::DivFOp>(one, operandA);
+    Value div = b.create<arith::DivFOp>(one, mul);
     rewriter.replaceOp(op, div);
     return success();
   }
 
-  // Check if the power is an integer
-  if (floatValueB != std::floor(floatValueB)) {
-    // We don't handle non-integer powers here, return failure
-    return failure();
-  }
-
-  auto sign = std::signbit(floatValueB) ? -1 : 1;
-  auto absIntValueB = std::abs(static_cast<int>(floatValueB));
-
-  auto cstOne =
-      createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
-  auto base = operandA;
-  if (sign == -1) {
-    base = b.create<arith::DivFOp>(cstOne, base);
-  }
-  auto current = base;
-  auto result = cstOne;
-  while (absIntValueB > 0) {
-    if (absIntValueB & 1) {
-      result = b.create<arith::MulFOp>(result, current);
-    }
-    current = b.create<arith::MulFOp>(current, current);
-    absIntValueB >>= 1;
-  }
-  rewriter.replaceOp(op, result);
-  return success();
+  return failure();
 }
 
 // Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
-// Restricting a >= 0
 static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operandA = op.getOperand(0);
   Value operandB = op.getOperand(1);
-  Type opType = operandA.getType();
-  Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
-  Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
-
-  Value logA = b.create<math::LogOp>(opType, operandA);
-  Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
-  Value expResult = b.create<math::ExpOp>(opType, mult);
-
-  // First, we select between the exp value and the adjusted value for odd
-  // powers of negatives. Then, we ensure that one is produced if `b` is zero.
-  // This corresponds to `libm` behavior, even for `0^0`. Without this check,
-  // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
-  Value zeroCheck =
-      b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
-  Value finalResult =
-      b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, expResult);
-  rewriter.replaceOp(op, finalResult);
+  Value logA = b.create<math::LogOp>(operandA);
+  Value mult = b.create<arith::MulFOp>(operandB, logA);
+  Value expResult = b.create<math::ExpOp>(mult);
+  rewriter.replaceOp(op, expResult);
   return success();
 }
 
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 5b443e9e8d4e7..3cf372ea0cf50 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -203,14 +203,10 @@ func.func @roundf_func(%a: f32) -> f32 {
 // CHECK-LABEL:   func @powf_func
 // CHECK-SAME:    ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
 func.func @powf_func(%a: f64, %b: f64) -> f64 {
-  // CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
-  // CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0
-  // CHECK: [[LOGA:%.+]] = math.log [[ARG0]]
-  // CHECK: [[MULB:%.+]] = arith.mulf [[ARG1]], [[LOGA]]
-  // CHECK: [[EXP:%.+]] = math.exp [[MULB]]
-  // CHECK: [[CMPF:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
-  // CHECK: [[SEL:%.+]] = arith.select [[CMPF]], [[CST1]], [[EXP]]
-  // CHECK: return [[SEL]]
+  // CHECK: [[LOGA:%.+]] = math.log [[ARG0]] : f64
+  // CHECK: [[MUL:%.+]] = arith.mulf [[ARG1]], [[LOGA]] : f64
+  // CHECK: [[EXP:%.+]] = math.exp [[MUL]] : f64
+  // CHECK: return [[EXP]] : f64
   %ret = math.powf %a, %b : f64
   return %ret : f64
 }
@@ -592,15 +588,11 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
   return %2 : tensor<8xf32>
 }
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
-// CHECK-DAG:    %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
-// CHECK-DAG:    %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
 // CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
 // CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : tensor<8xf32>
 // CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32>
 // CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
-// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : tensor<8xf32>
-// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
-// CHECK: return %[[SEL]]
+// CHECK: return %[[EXP]] : tensor<8xf32>
 // -----
 
 // CHECK-LABEL:   func.func @math_fpowi_to_powf_scalar
@@ -609,15 +601,11 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
   return %2 : f32
 }
 // CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 {
-// CHECK-DAG:    %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG:    %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
 // CHECK:        %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
 // CHECK:        %[[LOGA:.*]] = math.log %[[ARG0]] : f32
 // CHECK:        %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : f32
 // CHECK:        %[[EXP:.*]] = math.exp %[[MUL]] : f32
-// CHECK:        %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : f32
-// CHECK:        %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : f32
-// CHECK:       return %[[SEL]] : f32
+// CHECK:       return %[[EXP]] : f32
 
 // -----
 
diff --git a/mlir/test/mlir-runner/test-expand-math-approx.mlir b/mlir/test/mlir-runner/test-expand-math-approx.mlir
index d1916c28878b9..b599c9d8435d4 100644
--- a/mlir/test/mlir-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-runner/test-expand-math-approx.mlir
@@ -203,49 +203,61 @@ func.func @powf() {
   call @func_powff64(%a, %a_p) : (f64, f64) -> ()
 
   // CHECK-NEXT: 2.343
-  %c   = arith.constant 2.343 : f64
-  %c_p = arith.constant 1.000 : f64
-  call @func_powff64(%c, %c_p) : (f64, f64) -> ()
+  %b   = arith.constant 2.343 : f64
+  %b_p = arith.constant 1.000 : f64
+  call @func_powff64(%b, %b_p) : (f64, f64) -> ()
 
   // CHECK-NEXT: 0.176171
-  %d   = arith.constant 4.25 : f64
-  %d_p = arith.constant -1.2  : f64
-  call @func_powff64(%d, %d_p) : (f64, f64) -> ()
+  %c   = arith.constant 4.25 : f64
+  %c_p = arith.constant -1.2  : f64
+  call @func_powff64(%c, %c_p) : (f64, f64) -> ()
 
   // CHECK-NEXT: 1
-  %e   = arith.constant 4.385 : f64
-  %e_p = arith.constant 0.00 : f64
-  call @func_powff64(%e, %e_p) : (f64, f64) -> ()
+  %d   = arith.constant 4.385 : f64
+  %d_p = arith.constant 0.00 : f64
+  call @func_powff64(%d, %d_p) : (f64, f64) -> ()
 
   // CHECK-NEXT: 6.62637
-  %f    = arith.constant 4.835 : f64
-  %f_p  = arith.constant 1.2 : f64
-  call @func_powff64(%f, %f_p) : (f64, f64) -> ()
+  %e    = arith.constant 4.835 : f64
+  %e_p  = arith.constant 1.2 : f64
+  call @func_powff64(%e, %e_p) : (f64, f64) -> ()
 
   // CHECK-NEXT: nan
-  %i = arith.constant 1.0 : f64
-  %h = arith.constant 0x7fffffffffffffff : f64
-  call @func_powff64(%i, %h) : (f64, f64) -> ()
+  %f = arith.constant 1.0 : f64
+  %f_p = arith.constant 0x7fffffffffffffff : f64
+  call @func_powff64(%f, %f_p) : (f64, f64) -> ()
 
   // CHECK-NEXT: inf
-  %j   = arith.constant 29385.0 : f64
-  %j_p = arith.constant 23598.0 : f64
-  call @func_powff64(%j, %j_p) : (f64, f64) -> ()
+  %g   = arith.constant 29385.0 : f64
+  %g_p = arith.constant 23598.0 : f64
+  call @func_powff64(%g, %g_p) : (f64, f64) -> ()
 
   // CHECK-NEXT: -nan
-  %k = arith.constant 1.0 : f64
-  %k_p = arith.constant 0xfff0000001000000 : f64
-  call @func_powff64(%k, %k_p) : (f64, f64) -> ()
+  %h = arith.constant 1.0 : f64
+  %h_p = arith.constant 0xfff0000001000000 : f64
+  call @func_powff64(%h, %h_p) : (f64, f64) -> ()
 
   // CHECK-NEXT: -nan
-  %l = arith.constant 1.0 : f32
-  %l_p = arith.constant 0xffffffff : f32
-  call @func_powff32(%l, %l_p) : (f32, f32) -> ()
+  %i = arith.constant 1.0 : f32
+  %i_p = arith.constant 0xffffffff : f32
+  call @func_powff32(%i, %i_p) : (f32, f32) -> ()
 
   // CHECK-NEXT: 1
-  %zero = arith.constant 0.0 : f32
-  call @func_powff32(%zero, %zero) : (f32, f32) -> ()
+  %j = arith.constant 0.000 : f32
+  %j_r = math.powf %j, %j : f32
+  vector.print %j_r : f32
 
+  // CHECK-NEXT: 4
+  %k = arith.constant -2.0 : f32
+  %k_p = arith.constant 2.0 : f32
+  %k_r = math.powf %k, %k_p : f32
+  vector.print %k_r : f32
+
+  // CHECK-NEXT: 0.25
+  %l = arith.constant -2.0 : f32
+  %l_p = arith.constant -2.0 : f32
+  %l_r = math.powf %k, %l_p : f32
+  vector.print %l_r : f32
   return
 }
 

>From c52ba9fc6aa11743140fe0a6402ec93d68e3aed9 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Tue, 11 Feb 2025 11:59:43 +0900
Subject: [PATCH 04/12] formatting

---
 mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 9ad1ac2308838..5d9b264fd0d51 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -9,6 +9,7 @@
 // This file implements expansion of various math operations.
 //
 //===----------------------------------------------------------------------===//
+
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
@@ -16,13 +17,8 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "llvm/ADT/APFloat.h"
-#include "llvm/Support/LogicalResult.h"
-#include <cmath>
 
 using namespace mlir;
 

>From e1e06ec257ded2ae2441556cf72b7d3732cba6c9 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Wed, 12 Feb 2025 11:26:33 +0900
Subject: [PATCH 05/12] Give explicit benefit to convertSpecialPowfOp

---
 mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 5d9b264fd0d51..34ba98ca16a3e 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -703,7 +703,7 @@ void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
 }
 
 void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
-  patterns.add(convertSpecialPowfOp);
+  patterns.add(convertSpecialPowfOp, /*benefit=*/ 2);
   patterns.add(convertPowfOp);
 }
 

>From 9cf3d3bb2e4c387d53779460a935af76c4f3d94f Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Wed, 12 Feb 2025 11:33:30 +0900
Subject: [PATCH 06/12] formatting

---
 mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 34ba98ca16a3e..704001870601b 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -703,7 +703,7 @@ void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
 }
 
 void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
-  patterns.add(convertSpecialPowfOp, /*benefit=*/ 2);
+  patterns.add(convertSpecialPowfOp, /*benefit=*/2);
   patterns.add(convertPowfOp);
 }
 

>From ec9f1287c4f466a342797bc84c455299877efa49 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Wed, 12 Feb 2025 13:01:27 +0900
Subject: [PATCH 07/12] avoid using native float

---
 .../Math/Transforms/ExpandPatterns.cpp        | 30 +++++++++----------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 704001870601b..1d27686b3acc1 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APFloat.h"
 
 using namespace mlir;
 
@@ -318,55 +319,54 @@ static LogicalResult convertSpecialPowfOp(math::PowFOp op,
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operandA = op.getOperand(0);
   Value operandB = op.getOperand(1);
-  auto opType = operandA.getType();
-  auto baseType = operandB.getType();
-
-  auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
-                  .getFloatSemantics();
+  auto typeA = operandA.getType();
+  auto typeB = operandB.getType();
 
+  auto &sem =
+      cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
   auto valueB = APFloat(sem);
   if (!matchPattern(operandB, m_ConstantFloat(&valueB))) {
     // Not a constant, return failure
     return failure();
   }
-  float floatValueB = valueB.convertToFloat();
-  if (floatValueB == 0.0f) {
+
+  if (valueB.compare(APFloat(0.0f)) == APFloat::cmpEqual) {
     // a^0 -> 1
-    Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
+    Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
     rewriter.replaceOp(op, one);
     return success();
   }
-  if (floatValueB == 1.0f) {
+  if (valueB.compare(APFloat(1.0f)) == APFloat::cmpEqual) {
     // a^1 -> a
     rewriter.replaceOp(op, operandA);
     return success();
   }
-  if (floatValueB == -1.0f) {
+  if (valueB.compare(APFloat(-1.0f)) == APFloat::cmpEqual) {
     // a^(-1) -> 1 / a
-    Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
+    Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
     Value div = b.create<arith::DivFOp>(one, operandA);
     rewriter.replaceOp(op, div);
     return success();
   }
-  if (floatValueB == 0.5f) {
+  if (valueB.compare(APFloat(0.5f)) == APFloat::cmpEqual) {
     // a^(1/2) -> sqrt(a)
     Value sqrt = b.create<math::SqrtOp>(operandA);
     rewriter.replaceOp(op, sqrt);
     return success();
   }
-  if (floatValueB == -0.5f) {
+  if (valueB.compare(APFloat(-0.5f)) == APFloat::cmpEqual) {
     // a^(-1/2) -> 1 / sqrt(a)
     Value rsqrt = b.create<math::RsqrtOp>(operandA);
     rewriter.replaceOp(op, rsqrt);
     return success();
   }
-  if (floatValueB == 2.0f) {
+  if (valueB.compare(APFloat(2.0f)) == APFloat::cmpEqual) {
     // a^2 -> a * a
     Value mul = b.create<arith::MulFOp>(operandA, operandA);
     rewriter.replaceOp(op, mul);
     return success();
   }
-  if (floatValueB == -2.0f) {
+  if (valueB.compare(APFloat(-2.0f)) == APFloat::cmpEqual) {
     // a^(-2) -> 1 / (a * a)
     Value mul = b.create<arith::MulFOp>(operandA, operandA);
     Value one =

>From 95c3d55b1803ab057deb88266ef06be67ce1a496 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Wed, 12 Feb 2025 13:20:19 +0900
Subject: [PATCH 08/12] match APFloat Types

---
 .../Math/Transforms/ExpandPatterns.cpp        | 21 +++++++++++--------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 1d27686b3acc1..3718800437bf2 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -324,49 +324,52 @@ static LogicalResult convertSpecialPowfOp(math::PowFOp op,
 
   auto &sem =
       cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
-  auto valueB = APFloat(sem);
+  APFloat valueB(sem);
   if (!matchPattern(operandB, m_ConstantFloat(&valueB))) {
     // Not a constant, return failure
     return failure();
   }
-
-  if (valueB.compare(APFloat(0.0f)) == APFloat::cmpEqual) {
+  if (valueB.compare(APFloat::getZero(sem)) == APFloat::cmpEqual) {
     // a^0 -> 1
     Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
     rewriter.replaceOp(op, one);
     return success();
   }
-  if (valueB.compare(APFloat(1.0f)) == APFloat::cmpEqual) {
+  if (valueB.compare(APFloat::getOne(sem)) == APFloat::cmpEqual) {
     // a^1 -> a
     rewriter.replaceOp(op, operandA);
     return success();
   }
-  if (valueB.compare(APFloat(-1.0f)) == APFloat::cmpEqual) {
+  if (valueB.compare(-APFloat::getOne(sem)) == APFloat::cmpEqual) {
     // a^(-1) -> 1 / a
     Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
     Value div = b.create<arith::DivFOp>(one, operandA);
     rewriter.replaceOp(op, div);
     return success();
   }
-  if (valueB.compare(APFloat(0.5f)) == APFloat::cmpEqual) {
+  APFloat halfVal(0.5);
+  halfVal.convert(sem, APFloat::rmNearestTiesToEven, /*losesInfo=*/nullptr);
+  if (valueB.compare(halfVal) == APFloat::cmpEqual) {
     // a^(1/2) -> sqrt(a)
     Value sqrt = b.create<math::SqrtOp>(operandA);
     rewriter.replaceOp(op, sqrt);
     return success();
   }
-  if (valueB.compare(APFloat(-0.5f)) == APFloat::cmpEqual) {
+  if (valueB.compare(-halfVal) == APFloat::cmpEqual) {
     // a^(-1/2) -> 1 / sqrt(a)
     Value rsqrt = b.create<math::RsqrtOp>(operandA);
     rewriter.replaceOp(op, rsqrt);
     return success();
   }
-  if (valueB.compare(APFloat(2.0f)) == APFloat::cmpEqual) {
+  APFloat twoVal(2.0);
+  twoVal.convert(sem, APFloat::rmNearestTiesToEven, /*losesInfo=*/nullptr);
+  if (valueB.compare(twoVal) == APFloat::cmpEqual) {
     // a^2 -> a * a
     Value mul = b.create<arith::MulFOp>(operandA, operandA);
     rewriter.replaceOp(op, mul);
     return success();
   }
-  if (valueB.compare(APFloat(-2.0f)) == APFloat::cmpEqual) {
+  if (valueB.compare(-twoVal) == APFloat::cmpEqual) {
     // a^(-2) -> 1 / (a * a)
     Value mul = b.create<arith::MulFOp>(operandA, operandA);
     Value one =

>From 79c4ef4e5f18efa74edf6aeb89a056ff6af07a0a Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Wed, 12 Feb 2025 13:22:56 +0900
Subject: [PATCH 09/12] use APFloat::isZero() for the b == 0 check

---
 mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 3718800437bf2..8837c283f46be 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -329,7 +329,7 @@ static LogicalResult convertSpecialPowfOp(math::PowFOp op,
     // Not a constant, return failure
     return failure();
   }
-  if (valueB.compare(APFloat::getZero(sem)) == APFloat::cmpEqual) {
+  if (valueB.isZero()) {
     // a^0 -> 1
     Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
     rewriter.replaceOp(op, one);

>From 393aaa8305c5c0b86df6bfafad73549b92a1679e Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Thu, 13 Feb 2025 07:52:43 +0900
Subject: [PATCH 10/12] use APFloat::isExactlyValue and merge convertSpecialPowfOp into convertPowfOp

---
 .../Math/Transforms/ExpandPatterns.cpp        | 113 ++++++++----------
 1 file changed, 49 insertions(+), 64 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 8837c283f46be..c243ada80c8f3 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -312,10 +312,10 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
   return success();
 }
 
-// Convert Powf(float a, float b) for special cases when b is constant:
+// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
+// Some special cases where b is constant are handled separately:
 // when b == 0, or |b| == 0.5, 1.0, or 2.0.
-static LogicalResult convertSpecialPowfOp(math::PowFOp op,
-                                          PatternRewriter &rewriter) {
+static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operandA = op.getOperand(0);
   Value operandB = op.getOperand(1);
@@ -325,68 +325,54 @@ static LogicalResult convertSpecialPowfOp(math::PowFOp op,
   auto &sem =
       cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
   APFloat valueB(sem);
-  if (!matchPattern(operandB, m_ConstantFloat(&valueB))) {
-    // Not a constant, return failure
-    return failure();
-  }
-  if (valueB.isZero()) {
-    // a^0 -> 1
-    Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
-    rewriter.replaceOp(op, one);
-    return success();
-  }
-  if (valueB.compare(APFloat::getOne(sem)) == APFloat::cmpEqual) {
-    // a^1 -> a
-    rewriter.replaceOp(op, operandA);
-    return success();
-  }
-  if (valueB.compare(-APFloat::getOne(sem)) == APFloat::cmpEqual) {
-    // a^(-1) -> 1 / a
-    Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
-    Value div = b.create<arith::DivFOp>(one, operandA);
-    rewriter.replaceOp(op, div);
-    return success();
-  }
-  APFloat halfVal(0.5);
-  halfVal.convert(sem, APFloat::rmNearestTiesToEven, /*losesInfo=*/nullptr);
-  if (valueB.compare(halfVal) == APFloat::cmpEqual) {
-    // a^(1/2) -> sqrt(a)
-    Value sqrt = b.create<math::SqrtOp>(operandA);
-    rewriter.replaceOp(op, sqrt);
-    return success();
-  }
-  if (valueB.compare(-halfVal) == APFloat::cmpEqual) {
-    // a^(-1/2) -> 1 / sqrt(a)
-    Value rsqrt = b.create<math::RsqrtOp>(operandA);
-    rewriter.replaceOp(op, rsqrt);
-    return success();
-  }
-  APFloat twoVal(2.0);
-  twoVal.convert(sem, APFloat::rmNearestTiesToEven, /*losesInfo=*/nullptr);
-  if (valueB.compare(twoVal) == APFloat::cmpEqual) {
-    // a^2 -> a * a
-    Value mul = b.create<arith::MulFOp>(operandA, operandA);
-    rewriter.replaceOp(op, mul);
-    return success();
+  if (matchPattern(operandB, m_ConstantFloat(&valueB))) {
+    if (valueB.isZero()) {
+      // a^0 -> 1
+      Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
+      rewriter.replaceOp(op, one);
+      return success();
+    }
+    if (valueB.compare(APFloat::getOne(sem)) == APFloat::cmpEqual) {
+      // a^1 -> a
+      rewriter.replaceOp(op, operandA);
+      return success();
+    }
+    if (valueB.compare(-APFloat::getOne(sem)) == APFloat::cmpEqual) {
+      // a^(-1) -> 1 / a
+      Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
+      Value div = b.create<arith::DivFOp>(one, operandA);
+      rewriter.replaceOp(op, div);
+      return success();
+    }
+    if (valueB.isExactlyValue(0.5)) {
+      // a^(1/2) -> sqrt(a)
+      Value sqrt = b.create<math::SqrtOp>(operandA);
+      rewriter.replaceOp(op, sqrt);
+      return success();
+    }
+    if (valueB.isExactlyValue(-0.5)) {
+      // a^(-1/2) -> 1 / sqrt(a)
+      Value rsqrt = b.create<math::RsqrtOp>(operandA);
+      rewriter.replaceOp(op, rsqrt);
+      return success();
+    }
+    if (valueB.isExactlyValue(2.0)) {
+      // a^2 -> a * a
+      Value mul = b.create<arith::MulFOp>(operandA, operandA);
+      rewriter.replaceOp(op, mul);
+      return success();
+    }
+    if (valueB.isExactlyValue(-2.0)) {
+      // a^(-2) -> 1 / (a * a)
+      Value mul = b.create<arith::MulFOp>(operandA, operandA);
+      Value one =
+          createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
+      Value div = b.create<arith::DivFOp>(one, mul);
+      rewriter.replaceOp(op, div);
+      return success();
+    }
   }
-  if (valueB.compare(-twoVal) == APFloat::cmpEqual) {
-    // a^(-2) -> 1 / (a * a)
-    Value mul = b.create<arith::MulFOp>(operandA, operandA);
-    Value one =
-        createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
-    Value div = b.create<arith::DivFOp>(one, mul);
-    rewriter.replaceOp(op, div);
-    return success();
-  }
-
-  return failure();
-}
 
-// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
-static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
-  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
-  Value operandA = op.getOperand(0);
-  Value operandB = op.getOperand(1);
   Value logA = b.create<math::LogOp>(operandA);
   Value mult = b.create<arith::MulFOp>(operandB, logA);
   Value expResult = b.create<math::ExpOp>(mult);
@@ -706,7 +692,6 @@ void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
 }
 
 void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
-  patterns.add(convertSpecialPowfOp, /*benefit=*/2);
   patterns.add(convertPowfOp);
 }
 

>From f5205a6be34865eb93705724966d8ab11d7a0587 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Thu, 13 Feb 2025 09:18:00 +0900
Subject: [PATCH 11/12] use isExactlyValue for b == +/-1 and add powf special-case tests

---
 .../Math/Transforms/ExpandPatterns.cpp        |  4 +-
 mlir/test/Dialect/Math/expand-math.mlir       | 77 ++++++++++++++++++-
 2 files changed, 78 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index c243ada80c8f3..d7953719d44b5 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -332,12 +332,12 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
       rewriter.replaceOp(op, one);
       return success();
     }
-    if (valueB.compare(APFloat::getOne(sem)) == APFloat::cmpEqual) {
+    if (valueB.isExactlyValue(1.0)) {
       // a^1 -> a
       rewriter.replaceOp(op, operandA);
       return success();
     }
-    if (valueB.compare(-APFloat::getOne(sem)) == APFloat::cmpEqual) {
+    if (valueB.isExactlyValue(-1.0)) {
       // a^(-1) -> 1 / a
       Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
       Value div = b.create<arith::DivFOp>(one, operandA);
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 3cf372ea0cf50..280b133926a0c 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -211,6 +211,81 @@ func.func @powf_func(%a: f64, %b: f64) -> f64 {
   return %ret : f64
 }
 
+// CHECK-LABEL:   func @powf_func_zero
+// CHECK-SAME:    ([[ARG0:%.+]]: f64) -> f64
+func.func @powf_func_zero(%a: f64) -> f64{
+  // CHECK: [[ONE:%.+]] = arith.constant 1.000000e+00 : f64
+  // CHECK: return [[ONE]] : f64
+  %b = arith.constant 0.0 : f64
+  %ret = math.powf %a, %b : f64
+  return %ret : f64
+}
+
+// CHECK-LABEL:   func @powf_func_one
+// CHECK-SAME:    ([[ARG0:%.+]]: f64) -> f64
+func.func @powf_func_one(%a: f64) -> f64{
+  // CHECK: return [[ARG0]] : f64
+  %b = arith.constant 1.0 : f64
+  %ret = math.powf %a, %b : f64
+  return %ret : f64
+}
+
+
+// CHECK-LABEL:   func @powf_func_negone
+// CHECK-SAME:    ([[ARG0:%.+]]: f64) -> f64
+func.func @powf_func_negone(%a: f64) -> f64{
+  // CHECK: [[CSTONE:%.+]] = arith.constant 1.000000e+00 : f64
+  // CHECK: [[DIV:%.+]] = arith.divf [[CSTONE]], [[ARG0]] : f64
+  // CHECK: return [[DIV]] : f64
+  %b = arith.constant -1.0 : f64
+  %ret = math.powf %a, %b : f64
+  return %ret : f64
+}
+
+// CHECK-LABEL:   func @powf_func_half
+// CHECK-SAME:    ([[ARG0:%.+]]: f64) -> f64
+func.func @powf_func_half(%a: f64) -> f64{
+  // CHECK: [[SQRT:%.+]] = math.sqrt [[ARG0]] : f64
+  // CHECK: return [[SQRT]] : f64
+  %b = arith.constant 0.5 : f64
+  %ret = math.powf %a, %b : f64
+  return %ret : f64
+}
+
+// CHECK-LABEL:   func @powf_func_neghalf
+// CHECK-SAME:    ([[ARG0:%.+]]: f64) -> f64
+func.func @powf_func_neghalf(%a: f64) -> f64{
+  // CHECK: [[CSTONE:%.+]] = arith.constant 1.000000e+00 : f64
+  // CHECK: [[SQRT:%.+]] = math.sqrt [[ARG0]] : f64
+  // CHECK: [[DIV:%.+]] = arith.divf [[CSTONE]], [[SQRT]] : f64
+  // CHECK: return [[DIV]] : f64
+  %b = arith.constant -0.5 : f64
+  %ret = math.powf %a, %b : f64
+  return %ret : f64
+}
+
+// CHECK-LABEL:   func @powf_func_two
+// CHECK-SAME:    ([[ARG0:%.+]]: f64) -> f64
+func.func @powf_func_two(%a: f64) -> f64{
+  // CHECK: [[MUL:%.+]] = arith.mulf [[ARG0]], [[ARG0]] : f64
+  // CHECK: return [[MUL]] : f64
+  %b = arith.constant 2.0 : f64
+  %ret = math.powf %a, %b : f64
+  return %ret : f64
+}
+
+// CHECK-LABEL:   func @powf_func_negtwo
+// CHECK-SAME:    ([[ARG0:%.+]]: f64) -> f64
+func.func @powf_func_negtwo(%a: f64) -> f64{
+  // CHECK-DAG: [[MUL:%.+]] = arith.mulf [[ARG0]], [[ARG0]] : f64
+  // CHECK-DAG: [[CSTONE:%.+]] = arith.constant 1.000000e+00 : f64
+  // CHECK: [[DIV:%.+]] = arith.divf [[CSTONE]], [[MUL]] : f64
+  // CHECK: return [[DIV]] : f64
+  %b = arith.constant -2.0 : f64
+  %ret = math.powf %a, %b : f64
+  return %ret : f64
+}
+
 // -----
 
 // CHECK-LABEL:   func.func @roundeven64
@@ -592,7 +667,7 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
 // CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : tensor<8xf32>
 // CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32>
 // CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
-// CHECK: return %[[EXP]] : tensor<8xf32>
+// CHECK: return %[[EXP]]
 // -----
 
 // CHECK-LABEL:   func.func @math_fpowi_to_powf_scalar

>From d5ee522a217359264d35516a5b506149b5d1ab68 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Thu, 13 Feb 2025 09:18:17 +0900
Subject: [PATCH 12/12] remove stray blank line in expand-math.mlir

---
 mlir/test/Dialect/Math/expand-math.mlir | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 280b133926a0c..36a69ec83ce6e 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -230,7 +230,6 @@ func.func @powf_func_one(%a: f64) -> f64{
   return %ret : f64
 }
 
-
 // CHECK-LABEL:   func @powf_func_negone
 // CHECK-SAME:    ([[ARG0:%.+]]: f64) -> f64
 func.func @powf_func_negone(%a: f64) -> f64{



More information about the Mlir-commits mailing list