[Mlir-commits] [mlir] 391456f - Fix a bug in algebraic simplification, and enable the tests.

Eugene Zhulenev llvmlistbot at llvm.org
Tue Aug 10 04:16:07 PDT 2021


Author: bakhtiyar
Date: 2021-08-10T04:15:56-07:00
New Revision: 391456f33c7a2518721eda92b27630fb1c37e5d6

URL: https://github.com/llvm/llvm-project/commit/391456f33c7a2518721eda92b27630fb1c37e5d6
DIFF: https://github.com/llvm/llvm-project/commit/391456f33c7a2518721eda92b27630fb1c37e5d6.diff

LOG: Fix a bug in algebraic simplification, and enable the tests.

Reviewed By: ezhulenev

Differential Revision: https://reviews.llvm.org/D107788

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
    mlir/test/Dialect/Math/algebraic-simplification.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 2614fc7cf2f73..8918b21fb3e03 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -80,7 +80,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
     return success();
   }
 
-  // Replace `pow(x, 2.0)` with `x * x * x`.
+  // Replace `pow(x, 3.0)` with `x * x * x`.
   if (isExponentValue(3.0)) {
     Value square = rewriter.create<MulFOp>(op.getLoc(), ValueRange({x, x}));
     rewriter.replaceOpWithNewOp<MulFOp>(op, ValueRange({x, square}));
@@ -95,12 +95,18 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
     return success();
   }
 
-  // Replace `pow(x, -2.0)` with `sqrt(x)`.
-  if (isExponentValue(-1.0)) {
+  // Replace `pow(x, 0.5)` with `sqrt(x)`.
+  if (isExponentValue(0.5)) {
     rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x);
     return success();
   }
 
+  // Replace `pow(x, -0.5)` with `rsqrt(x)`.
+  if (isExponentValue(-0.5)) {
+    rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x);
+    return success();
+  }
+
   return failure();
 }
 

diff  --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir
index cb39bb7cd7f56..8a810760234fd 100644
--- a/mlir/test/Dialect/Math/algebraic-simplification.mlir
+++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir
@@ -49,3 +49,27 @@ func @pow_recip(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
   %1 = math.powf %arg1, %v : vector<4xf32>
   return %0, %1 : f32, vector<4xf32>
 }
+
+// CHECK-LABEL: @pow_sqrt
+func @pow_sqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK: %[[SCALAR:.*]] = math.sqrt %arg0
+  // CHECK: %[[VECTOR:.*]] = math.sqrt %arg1
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = constant 0.5 : f32
+  %v = constant dense <0.5> : vector<4xf32>
+  %0 = math.powf %arg0, %c : f32
+  %1 = math.powf %arg1, %v : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @pow_rsqrt
+func @pow_rsqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK: %[[SCALAR:.*]] = math.rsqrt %arg0
+  // CHECK: %[[VECTOR:.*]] = math.rsqrt %arg1
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = constant -0.5 : f32
+  %v = constant dense <-0.5> : vector<4xf32>
+  %0 = math.powf %arg0, %c : f32
+  %1 = math.powf %arg1, %v : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}


        


More information about the Mlir-commits mailing list