[Mlir-commits] [mlir] 10a57f3 - [mlir][math] Expand powfI operation for constant power operand. (#87081)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 1 00:48:31 PDT 2024

Author: Prashant Kumar
Date: 2024-04-01T13:18:27+05:30
New Revision: 10a57f3aff34be6ab43106dc1e45ace3f6da881c

URL: https://github.com/llvm/llvm-project/commit/10a57f3aff34be6ab43106dc1e45ace3f6da881c
DIFF: https://github.com/llvm/llvm-project/commit/10a57f3aff34be6ab43106dc1e45ace3f6da881c.diff

LOG: [mlir][math] Expand powfI operation for constant power operand. (#87081)

-- Convert `math.fpowi` to a series of `arith.mulf` operations.
-- If the power is negative, we divide 1 by the result.




diff  --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 11b2c7a7afa2f7..e2c513047c77a5 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -36,6 +36,7 @@ void populateExpandFloorFPattern(RewritePatternSet &patterns);
 void populateExpandCeilFPattern(RewritePatternSet &patterns);
 void populateExpandExp2FPattern(RewritePatternSet &patterns);
 void populateExpandPowFPattern(RewritePatternSet &patterns);
+void populateExpandFPowIPattern(RewritePatternSet &patterns);
 void populateExpandRoundFPattern(RewritePatternSet &patterns);
 void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);

diff  --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index e1ab9c905447b7..0b8546251350fa 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -1,4 +1,4 @@
-//===- ExpandTanh.cpp - Code to perform expanding tanh op -----------------===//
+//===- ExpandPatterns.cpp - Code to expand various math operations. -------===//
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
-// This file implements expansion of tanh op.
+// This file implements expansion of various math operations.
@@ -23,9 +23,14 @@
 using namespace mlir;
 /// Create a float constant.
-static Value createFloatConst(Location loc, Type type, double value,
+static Value createFloatConst(Location loc, Type type, APFloat value,
                               OpBuilder &b) {
-  auto attr = b.getFloatAttr(getElementTypeOrSelf(type), value);
+  bool losesInfo = false;
+  auto eltType = getElementTypeOrSelf(type);
+  // Convert double to the given `FloatType` with round-to-nearest-ties-to-even.
+  value.convert(cast<FloatType>(eltType).getFloatSemantics(),
+                APFloat::rmNearestTiesToEven, &losesInfo);
+  auto attr = b.getFloatAttr(eltType, value);
   if (auto shapedTy = dyn_cast<ShapedType>(type)) {
     return b.create<arith::ConstantOp>(loc,
                                        DenseElementsAttr::get(shapedTy, attr));
@@ -34,7 +39,12 @@ static Value createFloatConst(Location loc, Type type, double value,
   return b.create<arith::ConstantOp>(loc, attr);
-/// Create a float constant.
+static Value createFloatConst(Location loc, Type type, double value,
+                              OpBuilder &b) {
+  return createFloatConst(loc, type, APFloat(value), b);
+/// Create an integer constant.
 static Value createIntConst(Location loc, Type type, int64_t value,
                             OpBuilder &b) {
   auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
@@ -202,6 +212,69 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
   rewriter.replaceOp(op, ret);
   return success();
+// Convert `math.fpowi` to a series of `arith.mulf` operations.
+// If the power is negative, we divide one by the result.
+// If both the base and power are zero, the result is 1.
+static LogicalResult convertFPowICstOp(math::FPowIOp op,
+                                       PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value base = op.getOperand(0);
+  Value power = op.getOperand(1);
+  Type baseType = base.getType();
+  Attribute cstAttr;
+  if (!matchPattern(power, m_Constant(&cstAttr)))
+    return failure();
+  APInt value;
+  if (!matchPattern(cstAttr, m_ConstantInt(&value)))
+    return failure();
+  int64_t powerInt = value.getSExtValue();
+  bool isNegative = powerInt < 0;
+  int64_t absPower = std::abs(powerInt);
+  Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
+  Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
+  while (absPower > 0) {
+    if (absPower & 1)
+      res = b.create<arith::MulFOp>(baseType, base, res);
+    absPower >>= 1;
+    base = b.create<arith::MulFOp>(baseType, base, base);
+  }
+  // Make sure not to introduce UB in case of negative power.
+  if (isNegative) {
+    auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
+                    .getFloatSemantics();
+    Value zero =
+        createFloatConst(op->getLoc(), baseType,
+                         APFloat::getZero(sem, /*Negative=*/false), rewriter);
+    Value negZero =
+        createFloatConst(op->getLoc(), baseType,
+                         APFloat::getZero(sem, /*Negative=*/true), rewriter);
+    Value posInfinity =
+        createFloatConst(op->getLoc(), baseType,
+                         APFloat::getInf(sem, /*Negative=*/false), rewriter);
+    Value negInfinity =
+        createFloatConst(op->getLoc(), baseType,
+                         APFloat::getInf(sem, /*Negative=*/true), rewriter);
+    Value zeroEqCheck =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero);
+    Value negZeroEqCheck =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero);
+    res = b.create<arith::DivFOp>(baseType, one, res);
+    res =
+        b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
+    res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
+                                    res);
+  }
+  rewriter.replaceOp(op, res);
+  return success();
 // Converts  Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
 static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
@@ -517,6 +590,10 @@ void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
+void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
+  patterns.add(convertFPowICstOp);
 void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {

diff  --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 6326d3a71874b4..bfcff27bd64eb0 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -511,3 +511,102 @@ func.func @roundeven16(%arg: f16) -> f16 {
 // CHECK:     %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f16
 // CHECK: return %[[COPYSIGN]] : f16
+// -----
+// CHECK-LABEL:   func.func @math_fpowi_neg_odd_power
+func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
+  %1 = arith.constant dense<-3> : tensor<8xi64> 
+  %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
+  return %2 : tensor<8xf32>
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
+//  CHECK-DAG:    %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
+//  CHECK-DAG:    %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
+//  CHECK-DAG:    %[[CSTNEG0:.*]] = arith.constant dense<-0.000000e+00> : tensor<8xf32>
+//  CHECK-DAG:    %[[CSTINF:.*]] = arith.constant dense<0x7F800000> : tensor<8xf32>
+//  CHECK-DAG:    %[[CSTNEGINF:.*]] = arith.constant dense<0xFF800000> : tensor<8xf32>
+//  CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+//  CHECK:        %[[CUBE:.*]] = arith.mulf %[[SQ]], %[[ARG0]] : tensor<8xf32>
+//  CHECK:        %[[CMP0:.*]] = arith.cmpf oeq, %[[CUBE]], %[[CST0]] : tensor<8xf32>
+//  CHECK:        %[[CMPNEG0:.*]] = arith.cmpf oeq, %[[CUBE]], %[[CSTNEG0]] : tensor<8xf32>
+//  CHECK:        %[[INV:.*]] = arith.divf %[[CST1]], %[[CUBE]] : tensor<8xf32>
+//  CHECK:        %[[UB1:.*]] = arith.select %[[CMP0]], %[[CSTINF]], %[[INV]] : tensor<8xi1>, tensor<8xf32>
+//  CHECK:        %[[UB2:.*]] = arith.select %[[CMPNEG0]], %[[CSTNEGINF]], %[[UB1]] : tensor<8xi1>, tensor<8xf32>
+//  CHECK:      return %[[UB2]] : tensor<8xf32>
+// -----
+// CHECK-LABEL:   func.func @math_fpowi_neg_even_power
+func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
+  %1 = arith.constant dense<-4> : tensor<8xi64> 
+  %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
+  return %2 : tensor<8xf32>
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
+//  CHECK-DAG:    %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
+//  CHECK-DAG:    %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
+//  CHECK-DAG:    %[[CSTNEG0:.*]] = arith.constant dense<-0.000000e+00> : tensor<8xf32>
+//  CHECK-DAG:    %[[CSTINF:.*]] = arith.constant dense<0x7F800000> : tensor<8xf32>
+//  CHECK-DAG:    %[[CSTNEGINF:.*]] = arith.constant dense<0xFF800000> : tensor<8xf32>
+//  CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+//  CHECK:        %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
+//  CHECK:        %[[CMP0:.*]] = arith.cmpf oeq, %[[PW4]], %[[CST0]] : tensor<8xf32>
+//  CHECK:        %[[CMPNEG0:.*]] = arith.cmpf oeq, %[[PW4]], %[[CSTNEG0]] : tensor<8xf32>
+//  CHECK:        %[[INV:.*]] = arith.divf %[[CST1]], %[[PW4]] : tensor<8xf32>
+//  CHECK:        %[[UB1:.*]] = arith.select %[[CMP0]], %[[CSTINF]], %[[INV]] : tensor<8xi1>, tensor<8xf32>
+//  CHECK:        %[[UB2:.*]] = arith.select %[[CMPNEG0]], %[[CSTNEGINF]], %[[UB1]] : tensor<8xi1>, tensor<8xf32>
+//  CHECK:      return %[[UB2]] : tensor<8xf32>
+// -----
+// CHECK-LABEL:   func.func @math_fpowi_pos_odd_power
+func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
+  %1 = arith.constant dense<5> : tensor<8xi64> 
+  %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
+  return %2 : tensor<8xf32>
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
+//  CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+//  CHECK:        %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
+//  CHECK:        %[[PW5:.*]] = arith.mulf %[[PW4]], %[[ARG0]] : tensor<8xf32>
+//  CHECK:      return %[[PW5]] : tensor<8xf32>
+// -----
+// CHECK-LABEL:   func.func @math_fpowi_pos_even_power
+func.func @math_fpowi_pos_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
+  %1 = arith.constant dense<4> : tensor<8xi64> 
+  %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
+  return %2 : tensor<8xf32>
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
+//  CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+//  CHECK:        %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
+//  CHECK:      return %[[PW4]] : tensor<8xf32>
+// -----
+// CHECK-LABEL:   func.func @math_fpowi_even_scalar
+func.func @math_fpowi_even_scalar(%0 : f32) -> f32 {
+  %pow = arith.constant 2 : i64
+  %2 = math.fpowi %0, %pow : f32, i64
+  return %2 : f32
+//  CHECK-SAME: (%[[ARG0:.*]]: f32) -> f32 {
+//  CHECK:         %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
+//  CHECK:         return %[[SQ]] : f32
+// -----
+// CHECK-LABEL:   func.func @math_fpowi_scalar_zero
+func.func @math_fpowi_scalar_zero(%0 : f32) -> f32 {
+  %pow = arith.constant 0 : i64
+  %2 = math.fpowi %0, %pow : f32, i64
+  return %2 : f32
+//  CHECK-SAME: (%[[ARG0:.*]]: f32) -> f32 {
+//  CHECK:         %[[RET:.*]] = arith.constant 1.000000e+00 : f32
+//  CHECK:         return %[[RET]] : f32
+// -----

diff  --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index 7ce8b5a7cfe9b3..97600ad1ebe7a3 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -46,6 +46,7 @@ void TestExpandMathPass::runOnOperation() {
+  populateExpandFPowIPattern(patterns);
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));


More information about the Mlir-commits mailing list