[Mlir-commits] [mlir] [mlir][math] Expand powfI operation for constant power operand. (PR #87081)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 29 08:48:41 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-math

Author: Prashant Kumar (pashu123)

<details>
<summary>Changes</summary>

-- Convert `math.fpowi` with a constant power operand to a series of `arith.mulf` operations.
-- If the power is negative, we take the reciprocal, i.e. divide 1 by the result.

---
Full diff: https://github.com/llvm/llvm-project/pull/87081.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Math/Transforms/Passes.h (+1) 
- (modified) mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp (+46) 
- (modified) mlir/test/Dialect/Math/expand-math.mlir (+57) 
- (modified) mlir/test/lib/Dialect/Math/TestExpandMath.cpp (+1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 11b2c7a7afa2f7..e2c513047c77a5 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -36,6 +36,7 @@ void populateExpandFloorFPattern(RewritePatternSet &patterns);
 void populateExpandCeilFPattern(RewritePatternSet &patterns);
 void populateExpandExp2FPattern(RewritePatternSet &patterns);
 void populateExpandPowFPattern(RewritePatternSet &patterns);
+void populateExpandFPowIPattern(RewritePatternSet &patterns);
 void populateExpandRoundFPattern(RewritePatternSet &patterns);
 void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index e1ab9c905447b7..82947d064c9fff 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -202,6 +202,48 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
   rewriter.replaceOp(op, ret);
   return success();
 }
+
+// Convert `math.fpowi` with a constant exponent to a series of `arith.mulf`
+// operations (binary exponentiation). If the power is negative, the result
+// is the reciprocal: we divide 1 by the accumulated product.
+static LogicalResult convertFPowIOp(math::FPowIOp op,
+                                    PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value base = op.getOperand(0);
+  Type baseType = base.getType();
+
+  // Null-safe: the exponent may be a block argument with no defining op.
+  auto cstOp = op.getOperand(1).getDefiningOp<arith::ConstantOp>();
+  if (!cstOp)
+    return failure();
+
+  // Accept scalar and splat constants of any integer width up to 64 bits.
+  APInt power;
+  if (auto intAttr = dyn_cast<IntegerAttr>(cstOp.getValue()))
+    power = intAttr.getValue();
+  else if (auto splatAttr = dyn_cast<SplatElementsAttr>(cstOp.getValue()))
+    power = splatAttr.getSplatValue<APInt>();
+  else
+    return failure();
+  if (power.getBitWidth() > 64)
+    return failure();
+
+  // Take the magnitude as unsigned so that INT64_MIN does not overflow.
+  uint64_t absPower = power.abs().getZExtValue();
+  Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
+  Value res = one;
+  for (; absPower != 0; absPower >>= 1) {
+    if (absPower & 1)
+      res = b.create<arith::MulFOp>(baseType, base, res);
+    if (absPower > 1) // The last square would be dead; do not emit it.
+      base = b.create<arith::MulFOp>(baseType, base, base);
+  }
+  if (power.isNegative())
+    res = b.create<arith::DivFOp>(baseType, one, res);
+  rewriter.replaceOp(op, res);
+  return success();
+}
+
 // Converts  Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
 static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
@@ -517,6 +559,10 @@ void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
   patterns.add(convertPowfOp);
 }
 
+void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
+  patterns.add(convertFPowIOp); // Registers the math.fpowi expansion.
+}
+
 void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
   patterns.add(convertRoundOp);
 }
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 6326d3a71874b4..01dfe4783cfb69 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -511,3 +511,60 @@ func.func @roundeven16(%arg: f16) -> f16 {
 // CHECK:     %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f16
 
 // CHECK: return %[[COPYSIGN]] : f16
+
+// -----
+
+// CHECK-LABEL:   func.func @math_fpowi_neg_odd_power
+func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
+  %1 = arith.constant dense<-3> : tensor<8xi64> 
+  %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
+  return %2 : tensor<8xf32>
+}
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
+//  CHECK:        %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
+//  CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+//  CHECK:        %[[CUBE:.*]] = arith.mulf %[[SQ]], %[[ARG0]] : tensor<8xf32>
+//  CHECK:        %[[INV:.*]] = arith.divf %[[CST1]], %[[CUBE]] : tensor<8xf32>
+//  CHECK:      return %[[INV]] : tensor<8xf32>
+
+// -----
+
+// CHECK-LABEL:   func.func @math_fpowi_neg_even_power
+func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
+  %1 = arith.constant dense<-4> : tensor<8xi64> 
+  %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
+  return %2 : tensor<8xf32>
+}
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
+//  CHECK:        %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
+//  CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+//  CHECK:        %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
+//  CHECK:        %[[INV:.*]] = arith.divf %[[CST1]], %[[PW4]] : tensor<8xf32>
+//  CHECK:      return %[[INV]] : tensor<8xf32>
+
+// -----
+
+// CHECK-LABEL:   func.func @math_fpowi_pos_odd_power
+func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
+  %1 = arith.constant dense<5> : tensor<8xi64> 
+  %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
+  return %2 : tensor<8xf32>
+}
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
+//  CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+//  CHECK:        %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
+//  CHECK:        %[[PW5:.*]] = arith.mulf %[[PW4]], %[[ARG0]] : tensor<8xf32>
+//  CHECK:      return %[[PW5]] : tensor<8xf32>
+
+// -----
+
+// CHECK-LABEL:   func.func @math_fpowi_pos_even_power
+func.func @math_fpowi_pos_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
+  %1 = arith.constant dense<4> : tensor<8xi64> 
+  %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
+  return %2 : tensor<8xf32>
+}
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
+//  CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+//  CHECK:        %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
+//  CHECK:      return %[[PW4]] : tensor<8xf32>
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index 7ce8b5a7cfe9b3..97600ad1ebe7a3 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -46,6 +46,7 @@ void TestExpandMathPass::runOnOperation() {
   populateExpandFloorFPattern(patterns);
   populateExpandCeilFPattern(patterns);
   populateExpandPowFPattern(patterns);
+  populateExpandFPowIPattern(patterns);
   populateExpandRoundFPattern(patterns);
   populateExpandRoundEvenPattern(patterns);
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));

``````````

</details>


https://github.com/llvm/llvm-project/pull/87081


More information about the Mlir-commits mailing list