[Mlir-commits] [mlir] 095ce65 - [mlir][math] Simplify pow(x, 0.75) into sqrt(sqrt(x)) * sqrt(x).

Slava Zakharin llvmlistbot at llvm.org
Fri Nov 4 10:48:37 PDT 2022


Author: Slava Zakharin
Date: 2022-11-04T10:48:19-07:00
New Revision: 095ce655ec84fc21b6002808c698687c37f2bf12

URL: https://github.com/llvm/llvm-project/commit/095ce655ec84fc21b6002808c698687c37f2bf12
DIFF: https://github.com/llvm/llvm-project/commit/095ce655ec84fc21b6002808c698687c37f2bf12.diff

LOG: [mlir][math] Simplify pow(x, 0.75) into sqrt(sqrt(x)) * sqrt(x).

This trivial simplification (x^0.75 = x^0.5 * x^0.25) gives a 3.89%
speed-up on SPEC CPU2017 503.bwaves on Ice Lake.

Differential Revision: https://reviews.llvm.org/D137351

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
    mlir/test/Dialect/Math/algebraic-simplification.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index bea939a65022a..a1e6746b8fe9b 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -109,6 +109,15 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
     return success();
   }
 
+  // Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
+  if (isExponentValue(0.75)) {
+    Value pow_half = rewriter.create<math::SqrtOp>(op.getLoc(), x);
+    Value pow_quarter = rewriter.create<math::SqrtOp>(op.getLoc(), pow_half);
+    rewriter.replaceOpWithNewOp<arith::MulFOp>(
+        op, ValueRange{pow_half, pow_quarter});
+    return success();
+  }
+
   return failure();
 }
 

diff  --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir
index 806779ad9198d..21c9f7a8e7f17 100644
--- a/mlir/test/Dialect/Math/algebraic-simplification.mlir
+++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir
@@ -74,6 +74,22 @@ func.func @pow_rsqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>)
   return %0, %1 : f32, vector<4xf32>
 }
 
+// CHECK-LABEL: @pow_0_75
+func.func @pow_0_75(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK: %[[SQRT1S:.*]] = math.sqrt %arg0
+  // CHECK: %[[SQRT2S:.*]] = math.sqrt %[[SQRT1S]]
+  // CHECK: %[[SCALAR:.*]] = arith.mulf %[[SQRT1S]], %[[SQRT2S]]
+  // CHECK: %[[SQRT1V:.*]] = math.sqrt %arg1
+  // CHECK: %[[SQRT2V:.*]] = math.sqrt %[[SQRT1V]]
+  // CHECK: %[[VECTOR:.*]] = arith.mulf %[[SQRT1V]], %[[SQRT2V]]
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = arith.constant 0.75 : f32
+  %v = arith.constant dense <0.75> : vector<4xf32>
+  %0 = math.powf %arg0, %c : f32
+  %1 = math.powf %arg1, %v : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
 // CHECK-LABEL: @ipowi_zero_exp(
 // CHECK-SAME: %[[ARG0:.+]]: i32
 // CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>


        


More information about the Mlir-commits mailing list