[Mlir-commits] [mlir] 5b702be - [mlir][math] Convert math.fpowi to math.powf in case of non constant (#87472)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 3 09:49:30 PDT 2024


Author: Prashant Kumar
Date: 2024-04-03T22:19:26+05:30
New Revision: 5b702be1e80b8733786ac48ceaf04f2936616d1b

URL: https://github.com/llvm/llvm-project/commit/5b702be1e80b8733786ac48ceaf04f2936616d1b
DIFF: https://github.com/llvm/llvm-project/commit/5b702be1e80b8733786ac48ceaf04f2936616d1b.diff

LOG: [mlir][math] Convert math.fpowi to math.powf in case of non constant (#87472)

When the power operand is not a constant integer, convert math.fpowi to
math.powf by casting the power operand to the floating-point type of the
base (using arith.sitofp).

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
    mlir/test/Dialect/Math/expand-math.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 0b8546251350fa..42629e149e9ffb 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -216,20 +216,30 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
 // Convert `math.fpowi` to a series of `arith.mulf` operations.
 // If the power is negative, we divide one by the result.
 // If both the base and power are zero, the result is 1.
-static LogicalResult convertFPowICstOp(math::FPowIOp op,
-                                       PatternRewriter &rewriter) {
+// If the power is not a constant integer, we convert the operation to `math.powf`.
+static LogicalResult convertFPowIOp(math::FPowIOp op,
+                                    PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value base = op.getOperand(0);
   Value power = op.getOperand(1);
   Type baseType = base.getType();
 
+  auto convertFPowItoPowf = [&]() -> LogicalResult {
+    Value castPowerToFp =
+        rewriter.create<arith::SIToFPOp>(op.getLoc(), baseType, power);
+    Value res = rewriter.create<math::PowFOp>(op.getLoc(), baseType, base,
+                                              castPowerToFp);
+    rewriter.replaceOp(op, res);
+    return success();
+  };
+
   Attribute cstAttr;
   if (!matchPattern(power, m_Constant(&cstAttr)))
-    return failure();
+    return convertFPowItoPowf();
 
   APInt value;
   if (!matchPattern(cstAttr, m_ConstantInt(&value)))
-    return failure();
+    return convertFPowItoPowf();
 
   int64_t powerInt = value.getSExtValue();
   bool isNegative = powerInt < 0;
@@ -591,7 +601,7 @@ void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
 }
 
 void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
-  patterns.add(convertFPowICstOp);
+  patterns.add(convertFPowIOp);
 }
 
 void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {

diff  --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index bfcff27bd64eb0..3d94b55126d097 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -610,3 +610,51 @@ func.func @math_fpowi_scalar_zero(%0 : f32) -> f32 {
 //  CHECK:         return %[[RET]] : f32
 
 // -----
+
+// CHECK-LABEL:   func.func @math_fpowi_to_powf_tensor
+func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> tensor<8xf32> {
+  %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi32>
+  return %2 : tensor<8xf32>
+}
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
+// CHECK:        %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
+// CHECK:        %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
+// CHECK:        %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
+// CHECK:        %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
+// CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+// CHECK:        %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32>
+// CHECK:        %[[LG:.*]] = math.log %[[SQ]] : tensor<8xf32>
+// CHECK:        %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : tensor<8xf32>
+// CHECK:        %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
+// CHECK:        %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : tensor<8xf32>
+// CHECK:        %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : tensor<8xf32>
+// CHECK:        %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
+// CHECK:        %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
+// CHECK:        %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1>
+// CHECK:        %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
+// CHECK:      return %[[SEL]] : tensor<8xf32>
+
+// -----
+
+// CHECK-LABEL:   func.func @math_fpowi_to_powf_scalar
+func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
+  %2 = math.fpowi %0, %1 : f32, i64
+  return %2 : f32
+}
+// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 {
+// CHECK:        %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
+// CHECK:        %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK:        %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:        %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
+// CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
+// CHECK:        %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32
+// CHECK:        %[[LG:.*]] = math.log %[[SQ]] : f32
+// CHECK:        %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : f32
+// CHECK:        %[[EXP:.*]] = math.exp %[[MUL]] : f32
+// CHECK:        %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : f32
+// CHECK:        %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : f32
+// CHECK:        %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
+// CHECK:        %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
+// CHECK:        %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
+// CHECK:        %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
+// CHECK:       return %[[SEL]] : f32


        


More information about the Mlir-commits mailing list